Skip to content

Commit b94a37f

Browse files
committed
Add pattern for load_gather and store_scatter ops
1 parent 334e9bf commit b94a37f

File tree

3 files changed

+145
-1
lines changed

3 files changed

+145
-1
lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -751,6 +751,12 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
751751

752752
let builders = [
753753
OpBuilder<(ins "Type": $value, "Value": $source, "Value": $mask,
754+
"xegpu::CachePolicyAttr": $l1_hint,
755+
"xegpu::CachePolicyAttr": $l2_hint,
756+
"xegpu::CachePolicyAttr": $l3_hint)>,
757+
OpBuilder<(ins "Type": $value, "Value": $source,
758+
"ArrayRef<OpFoldResult>": $offsets, "Value": $mask,
759+
"IntegerAttr": $chunk_size,
754760
"xegpu::CachePolicyAttr": $l1_hint,
755761
"xegpu::CachePolicyAttr": $l2_hint,
756762
"xegpu::CachePolicyAttr": $l3_hint)>
@@ -859,6 +865,12 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
859865

860866
let builders = [
861867
OpBuilder<(ins "Value": $value, "Value": $dest, "Value": $mask,
868+
"xegpu::CachePolicyAttr": $l1_hint,
869+
"xegpu::CachePolicyAttr": $l2_hint,
870+
"xegpu::CachePolicyAttr": $l3_hint)>,
871+
OpBuilder<(ins "Value": $value, "Value": $dest,
872+
"ArrayRef<OpFoldResult>": $offsets, "Value": $mask,
873+
"IntegerAttr": $chunk_size,
862874
"xegpu::CachePolicyAttr": $l1_hint,
863875
"xegpu::CachePolicyAttr": $l2_hint,
864876
"xegpu::CachePolicyAttr": $l3_hint)>

mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -737,6 +737,22 @@ void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
737737
l1_hint, l2_hint, l3_hint);
738738
}
739739

740+
void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
741+
Type valueType, Value source,
742+
ArrayRef<OpFoldResult> offsets, Value mask,
743+
IntegerAttr chunk_size, xegpu::CachePolicyAttr l1_hint,
744+
xegpu::CachePolicyAttr l2_hint,
745+
xegpu::CachePolicyAttr l3_hint) {
746+
auto loc = source.getLoc();
747+
int64_t size = static_cast<int64_t>(offsets.size());
748+
auto type = VectorType::get(size, builder.getIndexType());
749+
auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
750+
auto offset = vector::FromElementsOp::create(builder, loc, type, values);
751+
752+
build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint,
753+
l2_hint, l3_hint);
754+
}
755+
740756
//===----------------------------------------------------------------------===//
741757
// XeGPU_StoreScatterOp
742758
//===----------------------------------------------------------------------===//
@@ -785,6 +801,24 @@ void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
785801
l2_hint, l3_hint);
786802
}
787803

804+
void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
805+
Value value, Value dest,
806+
ArrayRef<OpFoldResult> offsets, Value mask,
807+
IntegerAttr chunk_size,
808+
xegpu::CachePolicyAttr l1_hint,
809+
xegpu::CachePolicyAttr l2_hint,
810+
xegpu::CachePolicyAttr l3_hint) {
811+
auto loc = dest.getLoc();
812+
int64_t size = static_cast<int64_t>(offsets.size());
813+
auto type = VectorType::get(size, builder.getIndexType());
814+
auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
815+
auto offset = vector::FromElementsOp::create(builder, loc, type, values);
816+
817+
// Call the correct builder overload that does not expect result types.
818+
build(builder, state, value, dest, offset, mask, chunk_size, l1_hint, l2_hint,
819+
l3_hint);
820+
}
821+
788822
//===----------------------------------------------------------------------===//
789823
// XeGPU_UpdateOffsetOp
790824
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 99 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -685,6 +685,88 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
685685
}
686686
};
687687

688+
// This pattern transforms the LoadGatherOp with explicit offsets to load
689+
// subgroup data, similar to WgToSgLoadNdOpWithOffset.
690+
struct WgToSgLoadGatherOpWithOffset
691+
: public OpConversionPattern<xegpu::LoadGatherOp> {
692+
using OpConversionPattern<xegpu::LoadGatherOp>::OpConversionPattern;
693+
LogicalResult
694+
matchAndRewrite(xegpu::LoadGatherOp op, OneToNOpAdaptor adaptor,
695+
ConversionPatternRewriter &rewriter) const override {
696+
697+
if (!op.getOffsets())
698+
return failure();
699+
700+
Location loc = op.getLoc();
701+
VectorType resultType = op.getResult().getType();
702+
ArrayRef<int64_t> wgShape = resultType.getShape();
703+
704+
xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
705+
if (!layout || !layout.getSgLayout())
706+
return failure();
707+
708+
SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
709+
710+
SmallVector<Value> newLoadOps;
711+
auto chunkSizeAttr = rewriter.getI64IntegerAttr(op.getChunkSize().value_or(1));
712+
VectorType newTy = VectorType::get(sgShape, resultType.getElementType());
713+
for (auto [offsets, mask] :
714+
llvm::zip(adaptor.getOffsets(), adaptor.getMask())) {
715+
auto newLoadOp = rewriter.create<xegpu::LoadGatherOp>(
716+
loc, newTy, op.getSource(), offsets, mask,
717+
chunkSizeAttr, op.getL1HintAttr(), op.getL2HintAttr(),
718+
op.getL3HintAttr());
719+
xegpu::setLayoutAttr(newLoadOp->getResult(0),
720+
layout.dropSgLayoutAndData());
721+
newLoadOps.push_back(newLoadOp);
722+
}
723+
rewriter.replaceOpWithMultiple(op, {newLoadOps});
724+
return success();
725+
}
726+
};
727+
728+
// This pattern transforms the StoreScatterOp with explicit offsets to store
729+
// subgroup data, similar to WgToSgStoreNdOpWithOffset.
730+
struct WgToSgStoreScatterOpWithOffset
731+
: public OpConversionPattern<xegpu::StoreScatterOp> {
732+
using OpConversionPattern<xegpu::StoreScatterOp>::OpConversionPattern;
733+
LogicalResult
734+
matchAndRewrite(xegpu::StoreScatterOp op, OneToNOpAdaptor adaptor,
735+
ConversionPatternRewriter &rewriter) const override {
736+
737+
if (!op.getOffsets())
738+
return failure();
739+
740+
Location loc = op.getLoc();
741+
VectorType valueType = dyn_cast<VectorType>(op.getValue().getType());
742+
if (!valueType)
743+
return failure();
744+
745+
ArrayRef<int64_t> wgShape = valueType.getShape();
746+
xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getValue());
747+
if (!layout || !layout.getSgLayout())
748+
return failure();
749+
750+
auto chunkSizeOpt = op.getChunkSize();
751+
int64_t chunkSize = chunkSizeOpt ? static_cast<int64_t>(*chunkSizeOpt) : 1;
752+
auto chunkSizeAttr = rewriter.getI64IntegerAttr(chunkSize);
753+
for (auto [val, offs, mask] : llvm::zip(
754+
adaptor.getValue(), adaptor.getOffsets(), adaptor.getMask())) {
755+
rewriter.create<xegpu::StoreScatterOp>(
756+
loc, val, op.getDest(), offs, mask, chunkSizeAttr,
757+
op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
758+
// Update the layout_result_0 attribute to drop sg_layout and sg_data.
759+
if (auto layoutAttr =
760+
op->getAttrOfType<xegpu::LayoutAttr>("layout_result_0")) {
761+
if (auto newLayout = layoutAttr.dropSgLayoutAndData())
762+
op->setAttr("layout_result_0", newLayout);
763+
}
764+
}
765+
rewriter.eraseOp(op);
766+
return success();
767+
}
768+
};
769+
688770
} // namespace
689771

690772
namespace mlir {
@@ -694,7 +776,8 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
694776
WgToSgStoreNdOp, WgToSgUpdateNdOffsetOp, WgToSgDpasOp,
695777
WgToSgPrefetchNdOp, UnrealizedConversionCastOpPattern,
696778
WgToSgElementwiseOp, WgToSgVectorBroadcastOp,
697-
WgToSgConvertLayoutOp, WgToSgArithConstantOp>(
779+
WgToSgConvertLayoutOp, WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset,
780+
WgToSgStoreScatterOpWithOffset>(
698781
patterns.getContext());
699782
}
700783
} // namespace xegpu
@@ -815,6 +898,21 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
815898
return isLegal(xegpu::getLayoutAttr(op.getResult()));
816899
});
817900

901+
target.addDynamicallyLegalOp<xegpu::LoadGatherOp>(
902+
[=](xegpu::LoadGatherOp op) -> bool {
903+
auto layout = xegpu::getLayoutAttr(op.getResult());
904+
return isLegal(layout);
905+
});
906+
907+
target.addDynamicallyLegalOp<xegpu::StoreScatterOp>(
908+
[=](xegpu::StoreScatterOp op) -> bool {
909+
// Check if the layout attribute is present on the result.
910+
auto layout = op->getAttrOfType<xegpu::LayoutAttr>("layout_result_0");
911+
if (!layout)
912+
return true;
913+
return isLegal(layout);
914+
});
915+
818916
target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>(
819917
[=](xegpu::ConvertLayoutOp op) -> bool {
820918
return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout());

0 commit comments

Comments
 (0)