Skip to content
24 changes: 20 additions & 4 deletions mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -843,7 +843,8 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
AnyTypeOf<[XeGPU_MaskType, I1]>:$mask, OptionalAttr<I64Attr>:$chunk_size,
OptionalAttr<XeGPU_CacheHintAttr>:$l1_hint,
OptionalAttr<XeGPU_CacheHintAttr>:$l2_hint,
OptionalAttr<XeGPU_CacheHintAttr>:$l3_hint);
OptionalAttr<XeGPU_CacheHintAttr>:$l3_hint,
OptionalAttr<XeGPU_LayoutAttr>:$layout);
let results = (outs AnyTypeOf<[XeGPU_ValueType, XeGPU_ScalarType]>:$value);

let extraClassDeclaration = extraBaseClassDeclaration # [{
Expand Down Expand Up @@ -895,7 +896,14 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
"IntegerAttr": $chunk_size,
"xegpu::CachePolicyAttr": $l1_hint,
"xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l3_hint)>
"xegpu::CachePolicyAttr": $l3_hint)>,
OpBuilder<(ins "Type": $value, "Value": $source,
"ArrayRef<OpFoldResult>": $offsets, "Value": $mask,
"IntegerAttr": $chunk_size,
"xegpu::CachePolicyAttr": $l1_hint,
"xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l3_hint,
"xegpu::LayoutAttr": $layout)>
];

let hasVerifier = 1;
Expand Down Expand Up @@ -979,7 +987,8 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
AnyTypeOf<[XeGPU_MaskType, I1]>:$mask, OptionalAttr<I64Attr>:$chunk_size,
OptionalAttr<XeGPU_CacheHintAttr>:$l1_hint,
OptionalAttr<XeGPU_CacheHintAttr>:$l2_hint,
OptionalAttr<XeGPU_CacheHintAttr>:$l3_hint);
OptionalAttr<XeGPU_CacheHintAttr>:$l3_hint,
OptionalAttr<XeGPU_LayoutAttr>:$layout);

let extraClassDeclaration = extraBaseClassDeclaration#[{
Type getDestType() {
Expand Down Expand Up @@ -1030,7 +1039,14 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
"IntegerAttr": $chunk_size,
"xegpu::CachePolicyAttr": $l1_hint,
"xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l3_hint)>
"xegpu::CachePolicyAttr": $l3_hint)>,
OpBuilder<(ins "Value": $value, "Value": $dest,
"ArrayRef<OpFoldResult>": $offsets, "Value": $mask,
"IntegerAttr": $chunk_size,
"xegpu::CachePolicyAttr": $l1_hint,
"xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l3_hint,
"xegpu::LayoutAttr": $layout)>
];

let hasVerifier = 1;
Expand Down
6 changes: 5 additions & 1 deletion mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,15 @@ void removeLayoutAttrs(Operation *op);

/// Sets the DistributeLayoutAttr for a given OpOperand or OpResult by attaching
/// it to the owner's dictionary attributes.
/// If `respectPermLayout` is true, the existing permanent layout
/// attribute will be kept and assigned to the attribute dict instead
/// of the provided layout.
template <typename T,
typename = std::enable_if_t<std::is_same_v<T, OpOperand> ||
std::is_same_v<T, OpResult>>>
void setDistributeLayoutAttr(const T &operandOrResult,
const DistributeLayoutAttr layout);
const DistributeLayoutAttr layout,
bool respectPermLayout = false);

/// Set the DistributeLayoutAttr for each OpOperand and OpResult of the given
/// operation. If the operation contains regions, it is also applied recursively
Expand Down
12 changes: 8 additions & 4 deletions mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,8 @@ static LogicalResult lowerToScatteredLoadOp(vector::TransferReadOp readOp,
/*chunk_size=*/IntegerAttr{},
/*l1_hint=*/xegpu::CachePolicyAttr{},
/*l2_hint=*/xegpu::CachePolicyAttr{},
/*l3_hint=*/xegpu::CachePolicyAttr{});
/*l3_hint=*/xegpu::CachePolicyAttr{},
/*layout=*/nullptr);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#163071 to fill the gap


rewriter.replaceOp(readOp, gatherOp.getResult());
return success();
Expand Down Expand Up @@ -469,7 +470,8 @@ static LogicalResult lowerToScatteredStoreOp(vector::TransferWriteOp writeOp,
/*chunk_size=*/IntegerAttr{},
/*l1_hint=*/xegpu::CachePolicyAttr{},
/*l2_hint=*/xegpu::CachePolicyAttr{},
/*l3_hint=*/xegpu::CachePolicyAttr{});
/*l3_hint=*/xegpu::CachePolicyAttr{},
/*layout=*/nullptr);
rewriter.eraseOp(writeOp);
return success();
}
Expand Down Expand Up @@ -621,7 +623,8 @@ struct GatherLowering : public OpRewritePattern<vector::GatherOp> {
/*chunk_size=*/IntegerAttr{},
/*l1_hint=*/xegpu::CachePolicyAttr{},
/*l2_hint=*/xegpu::CachePolicyAttr{},
/*l3_hint=*/xegpu::CachePolicyAttr{});
/*l3_hint=*/xegpu::CachePolicyAttr{},
/*layout=*/nullptr);

auto selectOp =
arith::SelectOp::create(rewriter, loc, gatherOp.getMask(),
Expand Down Expand Up @@ -655,7 +658,8 @@ struct ScatterLowering : public OpRewritePattern<vector::ScatterOp> {
/*chunk_size=*/IntegerAttr{},
/*l1_hint=*/xegpu::CachePolicyAttr{},
/*l2_hint=*/xegpu::CachePolicyAttr{},
/*l3_hint=*/xegpu::CachePolicyAttr{});
/*l3_hint=*/xegpu::CachePolicyAttr{},
/*layout=*/nullptr);
rewriter.eraseOp(scatterOp);
return success();
}
Expand Down
41 changes: 37 additions & 4 deletions mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -859,7 +859,7 @@ void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
xegpu::CachePolicyAttr l2_hint,
xegpu::CachePolicyAttr l3_hint) {
build(builder, state, valueType, source, Value(), mask, IntegerAttr(),
l1_hint, l2_hint, l3_hint);
l1_hint, l2_hint, l3_hint, /*layout=*/nullptr);
}

void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
Expand All @@ -875,7 +875,24 @@ void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
auto offset = vector::FromElementsOp::create(builder, loc, type, values);

build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint,
l2_hint, l3_hint);
l2_hint, l3_hint, /*layout=*/nullptr);
}

// Builder overload that accepts the offsets as a list of OpFoldResult together
// with a layout attribute. The offsets are turned into Values through
// getValueOrCreateConstantIndexOp, packed into a single index vector with
// vector::FromElementsOp, and then forwarded, along with `layout`, to the
// builder overload that takes the packed offsets.
void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
                         Type valueType, Value source,
                         ArrayRef<OpFoldResult> offsets, Value mask,
                         IntegerAttr chunk_size, xegpu::CachePolicyAttr l1_hint,
                         xegpu::CachePolicyAttr l2_hint,
                         xegpu::CachePolicyAttr l3_hint,
                         xegpu::LayoutAttr layout) {
  Location loc = source.getLoc();

  // The packed offsets vector holds one index element per provided offset.
  int64_t numOffsets = static_cast<int64_t>(offsets.size());
  auto offsetVecType = VectorType::get(numOffsets, builder.getIndexType());

  auto offsetValues = getValueOrCreateConstantIndexOp(builder, loc, offsets);
  auto offsetVec =
      vector::FromElementsOp::create(builder, loc, offsetVecType, offsetValues);

  build(builder, state, valueType, source, offsetVec, mask, chunk_size,
        l1_hint, l2_hint, l3_hint, layout);
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -926,7 +943,7 @@ void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
xegpu::CachePolicyAttr l2_hint,
xegpu::CachePolicyAttr l3_hint) {
build(builder, state, value, dest, Value(), mask, IntegerAttr(), l1_hint,
l2_hint, l3_hint);
l2_hint, l3_hint, /*layout=*/nullptr);
}

void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
Expand All @@ -944,7 +961,23 @@ void StoreScatterOp::build(OpBuilder &builder, OperationState &state,

// Call the correct builder overload that does not expect result types.
build(builder, state, value, dest, offset, mask, chunk_size, l1_hint, l2_hint,
l3_hint);
l3_hint, /*layout=*/nullptr);
}

// Builder overload that accepts the offsets as a list of OpFoldResult together
// with a layout attribute. The offsets are turned into Values through
// getValueOrCreateConstantIndexOp, packed into a single index vector with
// vector::FromElementsOp, and then forwarded, along with `layout`, to the
// builder overload that takes the packed offsets.
void StoreScatterOp::build(
    OpBuilder &builder, OperationState &state, Value value, Value dest,
    ArrayRef<OpFoldResult> offsets, Value mask, IntegerAttr chunk_size,
    xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint,
    xegpu::CachePolicyAttr l3_hint, xegpu::LayoutAttr layout) {
  Location loc = dest.getLoc();

  // The packed offsets vector holds one index element per provided offset.
  int64_t numOffsets = static_cast<int64_t>(offsets.size());
  auto offsetVecType = VectorType::get(numOffsets, builder.getIndexType());

  auto offsetValues = getValueOrCreateConstantIndexOp(builder, loc, offsets);
  auto offsetVec =
      vector::FromElementsOp::create(builder, loc, offsetVecType, offsetValues);

  // Dispatch to the builder overload that does not expect result types.
  build(builder, state, value, dest, offsetVec, mask, chunk_size, l1_hint,
        l2_hint, l3_hint, layout);
}

//===----------------------------------------------------------------------===//
Expand Down
15 changes: 11 additions & 4 deletions mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -904,9 +904,16 @@ void LayoutInfoPropagation::visitStoreScatterOp(
if (dstTdescTy.getChunkSizeAsInt() > 1)
instData.push_back(chunkSize);
}
LayoutInfo payloadLayout = getDefaultSIMTLayoutInfo(
payloadTy, uArch, instData, uArch->getGeneralPackedFormatBitSize(),
/*scattered=*/true);

LayoutInfo payloadLayout;

if (auto layout = storeScatter.getLayoutAttr()) {
payloadLayout = LayoutInfo(layout);
} else {
payloadLayout = getDefaultSIMTLayoutInfo(
payloadTy, uArch, instData, uArch->getGeneralPackedFormatBitSize(),
/*scattered=*/true);
}

LayoutInfo maskLayout =
getDefaultSIMTLayoutInfo(storeScatter->getContext(), 1, subgroupSize);
Expand Down Expand Up @@ -1041,7 +1048,7 @@ static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
}
// If the result is a vector type, add a temporary layout attribute to the
// op.
xegpu::setDistributeLayoutAttr(result, layout);
xegpu::setDistributeLayoutAttr(result, layout, /*respectPermLayout*/ true);
}
return success();
}
Expand Down
12 changes: 10 additions & 2 deletions mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -678,12 +678,16 @@ struct UnrollLoadGatherOpWithOffset
pack(offsets, convertedOffsetTypes, *targetShape, loc, rewriter);
}

auto layout = dyn_cast_if_present<xegpu::LayoutAttr>(op.getLayoutAttr());
if (layout)
layout = layout.dropInstData();

SmallVector<Value> newOps;
for (auto [o, m] : llvm::zip(convertedOffsets, convertedMasks)) {
auto newOp = xegpu::LoadGatherOp::create(
rewriter, loc, newValueTy, op.getSource(), o, m,
rewriter.getI64IntegerAttr(chunkSize), op.getL1HintAttr(),
op.getL2HintAttr(), op.getL3HintAttr());
op.getL2HintAttr(), op.getL3HintAttr(), layout);
newOps.push_back(newOp);
}

Expand Down Expand Up @@ -774,12 +778,16 @@ struct UnrollStoreScatterOpWithOffsets
SmallVector<Value> convertedValues =
pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);

auto layout = dyn_cast_if_present<xegpu::LayoutAttr>(op.getLayoutAttr());
if (layout)
layout = layout.dropInstData();

for (auto [v, o, m] :
llvm::zip(convertedValues, convertedOffsets, convertedMasks)) {
xegpu::StoreScatterOp::create(rewriter, loc, v, op.getDest(), o, m,
rewriter.getI64IntegerAttr(chunkSize),
op.getL1HintAttr(), op.getL2HintAttr(),
op.getL3HintAttr());
op.getL3HintAttr(), layout);
}

rewriter.eraseOp(op);
Expand Down
16 changes: 8 additions & 8 deletions mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -888,8 +888,8 @@ struct WgToSgLoadGatherOpWithOffset
return failure();
ArrayRef<int64_t> wgShape = resultType.getShape();

xegpu::DistributeLayoutAttr layout =
xegpu::getDistributeLayoutAttr(op.getResult());
xegpu::LayoutAttr layout = dyn_cast_if_present<xegpu::LayoutAttr>(
xegpu::getDistributeLayoutAttr(op.getResult()));
if (!layout || !layout.isForWorkgroup())
return failure();

Expand All @@ -914,9 +914,8 @@ struct WgToSgLoadGatherOpWithOffset
llvm::zip(adaptor.getOffsets(), adaptor.getMask())) {
auto newLoadOp = xegpu::LoadGatherOp::create(
rewriter, loc, newTy, op.getSource(), offsets, mask, chunkSizeAttr,
op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
xegpu::setDistributeLayoutAttr(newLoadOp->getResult(0),
layout.dropSgLayoutAndData());
op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(),
layout.dropSgLayoutAndData());
newLoadOps.push_back(newLoadOp);
}
rewriter.replaceOpWithMultiple(op, {newLoadOps});
Expand All @@ -941,8 +940,8 @@ struct WgToSgStoreScatterOpWithOffset
if (!valueType)
return failure();

xegpu::DistributeLayoutAttr layout =
xegpu::getDistributeLayoutAttr(op.getOperand(0));
xegpu::LayoutAttr layout = dyn_cast_if_present<xegpu::LayoutAttr>(
xegpu::getDistributeLayoutAttr(op.getOperand(0)));
if (!layout || !layout.isForWorkgroup())
return failure();

Expand All @@ -964,7 +963,8 @@ struct WgToSgStoreScatterOpWithOffset
adaptor.getValue(), adaptor.getOffsets(), adaptor.getMask())) {
auto store = xegpu::StoreScatterOp::create(
rewriter, loc, val, op.getDest(), offs, mask, chunkSizeAttr,
op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(),
layout.dropSgLayoutAndData());
// Update the layout attribute to drop sg_layout and sg_data.
if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
!layout.getEffectiveInstDataAsInt().empty()) {
Expand Down
65 changes: 60 additions & 5 deletions mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,11 @@ xegpu::DistributeLayoutAttr xegpu::getDistributeLayoutAttr(const Value value) {
std::string layoutName = getLayoutName(result);
if (defOp->hasAttr(layoutName))
return defOp->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);

// Check for the "permanent" layout only after the "temporary" layout name
// lookup, for backward compatibility.
if (auto loadGatherOp = dyn_cast<xegpu::LoadGatherOp>(defOp))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no store_matrix here?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think since the store scatter does not return a result it is not needed here. We only need load gather.

return loadGatherOp.getLayoutAttr();
}

if (auto arg = dyn_cast<BlockArgument>(value)) {
Expand Down Expand Up @@ -171,27 +176,77 @@ xegpu::getDistributeLayoutAttr(const OpOperand &opr) {
std::string layoutName = xegpu::getLayoutName(opr);
if (op->hasAttr(layoutName))
return op->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);

// Check for the "permanent" layout only after the "temporary" layout name
// lookup.
if (auto storeScatterOp = dyn_cast<xegpu::StoreScatterOp>(op))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure why there is only storescatter, no loadgather? Note that there are both load_matrix and store_matrix.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess we should handle the anchor ops first. So all of the following ops should be processed first here.

Load/Store matrix
load/store nd
loadgather/storescatter

I also think store ops can be omitted (but fine to have in code), because they don't return anything. so they can never be used as an OpOperand

Copy link
Contributor Author

@dchigarev dchigarev Nov 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure why there is only storescatter, no loadgather? Note that there are both load_matrix and store_matrix.

this function handles OpOperands while the permament layout attr for the load-op describes OpResult, so no reason to access the layout here

if (auto layout = storeScatterOp.getLayoutAttr())
return layout;

return getDistributeLayoutAttr(opr.get());
}

// Chooses the layout to use for an op result. When `owner` is a
// xegpu::LoadGatherOp that carries a non-null layout attribute, that
// (permanent) attribute is returned; in every other case the provided `layout`
// is returned unchanged. `result` and `name` are currently not consulted.
xegpu::DistributeLayoutAttr
maybePickPermanentLayout(xegpu::DistributeLayoutAttr layout,
                         const OpResult &result, mlir::Operation *owner,
                         const std::string &name) {
  auto gatherOp = dyn_cast<xegpu::LoadGatherOp>(owner);
  if (!gatherOp)
    return layout;

  if (auto permanent = gatherOp.getLayoutAttr())
    return permanent;

  return layout;
}

// Chooses the layout to use for an op operand. When `owner` (the op that owns
// the operand) is a xegpu::StoreScatterOp, the operand is operand #0 and the
// op carries a non-null layout attribute, that (permanent) attribute is
// returned; in every other case the provided `layout` is returned unchanged.
// `name` is currently not consulted.
xegpu::DistributeLayoutAttr
maybePickPermanentLayout(xegpu::DistributeLayoutAttr layout,
                         const OpOperand &operand, mlir::Operation *owner,
                         const std::string &name) {
  auto scatterOp = dyn_cast<xegpu::StoreScatterOp>(owner);
  if (!scatterOp)
    return layout;

  // NOTE(review): the const_cast is presumably required because
  // getOperandNumber() is not const-qualified — confirm against OpOperand.
  unsigned operandIdx = const_cast<OpOperand &>(operand).getOperandNumber();
  if (operandIdx != 0)
    return layout;

  if (auto permanent = scatterOp.getLayoutAttr())
    return permanent;

  return layout;
}

// Attaches a DistributeLayoutAttr to the op owning `operandOrResult`, stored
// under the attribute name produced by xegpu::getLayoutName(). An attribute of
// that type already present under the same name is never overwritten. When
// `respectPermLayout` is true, the layout to attach is first routed through
// maybePickPermanentLayout(), which may substitute a layout taken from the
// owner op for the provided one.
template <typename T, typename>
void xegpu::setDistributeLayoutAttr(const T &operandOrResult,
const DistributeLayoutAttr layout) {
const DistributeLayoutAttr layout,
bool respectPermLayout) {
Operation *owner = operandOrResult.getOwner();
std::string name = xegpu::getLayoutName(operandOrResult);
if (layout && !owner->hasAttrOfType<DistributeLayoutAttr>(name))
owner->setAttr(name, layout);

// A layout already stored under this name takes precedence; leave it as is.
if (owner->hasAttrOfType<DistributeLayoutAttr>(name))
return;

// Start from the provided layout and optionally let the owner's own layout
// attribute replace it.
DistributeLayoutAttr candidate = layout;
if (respectPermLayout)
candidate = maybePickPermanentLayout(layout, operandOrResult, owner, name);

// Only attach non-null layouts.
if (candidate)
owner->setAttr(name, candidate);
}

// Explicit instantiation for OpResult
template void xegpu::setDistributeLayoutAttr<mlir::OpResult>(
const mlir::OpResult &result,
const mlir::xegpu::DistributeLayoutAttr layout);
const mlir::xegpu::DistributeLayoutAttr layout, bool respectPermLayout);

// Explicit instantiation for OpOperand
template void xegpu::setDistributeLayoutAttr<mlir::OpOperand>(
const mlir::OpOperand &operand,
const mlir::xegpu::DistributeLayoutAttr layout);
const mlir::xegpu::DistributeLayoutAttr layout, bool respectPermLayout);

void xegpu::setDistributeLayoutAttrs(
Operation *op, function_ref<DistributeLayoutAttr(Value)> getLayoutImpl) {
Expand Down
Loading