Skip to content

Commit 317dae1

Browse files
chencha3 and adam-smnk
authored
[mlir][xegpu] Add initial skeleton implementation for lowering ConvertLayoutOp (#146176)
This PR adds an initial skeleton implementation for lowering ConvertLayoutOp. It currently only supports cases where SLM is not needed.

Co-authored-by: Adam Siemieniuk <[email protected]>
1 parent f97adea commit 317dae1

File tree

10 files changed

+245
-50
lines changed

10 files changed

+245
-50
lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1010,21 +1010,22 @@ def XeGPU_FenceOp: XeGPU_Op<"fence", []> {
10101010
def XeGPU_ConvertLayoutOp: XeGPU_Op<"convert_layout", [Pure, AllTypesMatch<["source", "result"]>]> {
10111011
let summary = "Convert the layout of the input operand";
10121012
let description = [{
1013-
`convert_layout` adjusts the data distribution across subgroups and/or work-items by modifying
1014-
the `LayoutAttr`. Both `srcMap` and `resMap` must correspond to the same programming scope, such
1015-
as workgroup-level (wg) or subgroup-level (sg) code. This operation is not valid once the IR is
1016-
lowered to WI level because that is the end result of all distributions.
1013+
`convert_layout` redistribute data across subgroups and/or work-items from the `input_layout` to
1014+
the `target_layout`. Both `input_layout` and `target_layout` must correspond to the same programming
1015+
scope, such as workgroup-level (wg) or subgroup-level (sg) code. This operation is not valid once
1016+
the IR is lowered to WI level because that is the end result of all distributions.
10171017
}];
1018-
let arguments = (ins XeGPU_Vector2DType: $source,
1019-
XeGPU_LayoutAttr: $srcMap,
1020-
XeGPU_LayoutAttr: $resMap
1021-
);
1022-
let results = (outs XeGPU_Vector2DType: $result);
1018+
let arguments = (ins XeGPU_VectorType: $source,
1019+
XeGPU_LayoutAttr: $input_layout,
1020+
XeGPU_LayoutAttr: $target_layout);
1021+
let results = (outs XeGPU_VectorType: $result);
10231022
let assemblyFormat = [{
1024-
$source attr-dict `:` type($source)
1023+
$source prop-dict attr-dict `:` type($source)
10251024
}];
10261025

1026+
let hasFolder = 1;
10271027
let hasVerifier = 1;
1028+
let hasCanonicalizer = 1;
10281029
}
10291030

10301031
#endif // MLIR_DIALECT_XEGPU_IR_XEGPUOPS_TD

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def XeGPU_DpasResType: FixedVectorOfRankAndType<[1, 2], [XeGPU_ScalarType]>;
2222
def XeGPU_OffsetType: FixedVectorOfNonZeroRankOf<[Index]>;
2323
def XeGPU_MaskType: FixedVectorOfNonZeroRankOf<[I1]>;
2424
def XeGPU_ValueType: FixedVectorOfNonZeroRankOf<[XeGPU_ScalarType]>;
25-
def XeGPU_Vector2DType: FixedVectorOfRankAndType<[2], [XeGPU_ScalarType]>;
25+
def XeGPU_VectorType: VectorOfRankAndType<[1,2,3,4,5,6], [XeGPU_ScalarType]>;
2626

2727
// common base class for types in XeGPU dialect
2828
class XeGPUTypeDef<string name, string typeMnemonic, list<Trait> traits = [],

mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp

Lines changed: 42 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -786,32 +786,55 @@ LogicalResult DpasOp::verify() {
786786
// XeGPU_ConvertLayoutOp
787787
//===----------------------------------------------------------------------===//
788788
LogicalResult ConvertLayoutOp::verify() {
789-
auto srcMap = getSrcMapAttr();
790-
auto resMap = getResMapAttr();
791-
if (!srcMap)
792-
return emitOpError("expected srcMap.");
793-
if (!resMap)
794-
return emitOpError("expected resMap.");
795-
796-
if (srcMap == resMap)
797-
return emitOpError("expected different srcMap and resMap.");
798-
799-
// both srcMap and resMap should be WgLayout or SgLayout at the same time.
800-
if ((!srcMap.isWgLayout() || !resMap.isWgLayout()) &&
801-
(!srcMap.isSgLayout() || !resMap.isSgLayout()))
802-
return emitOpError(
803-
"expected srcMap and resMap be WgLayout or SgLayout at the same time.");
789+
auto srcLayout = getInputLayout();
790+
auto resLayout = getTargetLayout();
791+
if (!srcLayout)
792+
return emitOpError("expected input layout.");
793+
if (!resLayout)
794+
return emitOpError("expected target layout.");
795+
796+
// both input and target layouts should be WgLayout or SgLayout at the same
797+
// time.
798+
if ((!srcLayout.isWgLayout() || !resLayout.isWgLayout()) &&
799+
(!srcLayout.isSgLayout() || !resLayout.isSgLayout()))
800+
return emitOpError("expected input layout and target layout be WgLayout or "
801+
"SgLayout at the same time.");
804802

805803
auto shape = getSource().getType().getShape();
806-
if (!XeGPUDialect::isEvenlyDistributable(shape, srcMap))
807-
return emitOpError("invalid srcMap, data cannot be evenly distributed.");
804+
if (!XeGPUDialect::isEvenlyDistributable(shape, srcLayout))
805+
return emitOpError(
806+
"invalid input layout, data cannot be evenly distributed.");
808807

809-
if (!XeGPUDialect::isEvenlyDistributable(shape, resMap))
810-
return emitOpError("invalid resMap, data cannot be evenly distributed.");
808+
if (!XeGPUDialect::isEvenlyDistributable(shape, resLayout))
809+
return emitOpError(
810+
"invalid target layout, data cannot be evenly distributed.");
811811

812812
return mlir::success();
813813
}
814814

815+
OpFoldResult ConvertLayoutOp::fold(FoldAdaptor adaptor) {
816+
if (getInputLayout() == getTargetLayout())
817+
return getSource();
818+
return {};
819+
}
820+
821+
struct FoldConvertLayoutOp : public OpRewritePattern<xegpu::ConvertLayoutOp> {
822+
using OpRewritePattern<xegpu::ConvertLayoutOp>::OpRewritePattern;
823+
LogicalResult matchAndRewrite(xegpu::ConvertLayoutOp op,
824+
PatternRewriter &rewriter) const override {
825+
if (op.getInputLayout() == op.getTargetLayout()) {
826+
rewriter.replaceOp(op, op.getSource());
827+
return success();
828+
}
829+
return failure();
830+
}
831+
};
832+
833+
void ConvertLayoutOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
834+
MLIRContext *context) {
835+
patterns.add<FoldConvertLayoutOp>(context);
836+
}
837+
815838
} // namespace xegpu
816839
} // namespace mlir
817840

mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,29 @@ resolveUnrealizedConversionCastOp(UnrealizedConversionCastOp castOp) {
7676
}
7777
}
7878

79+
// This pattern lowers ConvertLayoutOp by removing the inst_data field from the
80+
// layout attributes. Since both producer and consumer operations handle data
81+
// partitioning based on their own inst_data, while maintaining original input
82+
// and output shape, ConvertLayoutOp does not need to manage inst_data.
83+
struct ConvertLayoutOpPattern
84+
: public OpRewritePattern<xegpu::ConvertLayoutOp> {
85+
using OpRewritePattern::OpRewritePattern;
86+
LogicalResult matchAndRewrite(xegpu::ConvertLayoutOp op,
87+
PatternRewriter &rewriter) const override {
88+
xegpu::LayoutAttr input_layout = op.getInputLayoutAttr();
89+
xegpu::LayoutAttr target_layout = op.getTargetLayoutAttr();
90+
if (!input_layout.getInstData() || !target_layout.getInstData())
91+
return rewriter.notifyMatchFailure(op, "Not a target ConvertLayoutOp.");
92+
93+
input_layout = input_layout.dropInstData();
94+
target_layout = target_layout.dropInstData();
95+
auto newOp = rewriter.createOrFold<xegpu::ConvertLayoutOp>(
96+
op.getLoc(), op.getType(), op.getSource(), input_layout, target_layout);
97+
rewriter.replaceOp(op, newOp);
98+
return success();
99+
}
100+
};
101+
79102
//===------------------------------------------------------------------------===//
80103
// The XeGPUBlockingPass leverages the unroll patterns for XeGPU and Vector ops
81104
// to partition operations that process large shapes into multiple operations on
@@ -331,6 +354,7 @@ void XeGPUBlockingPass::runOnOperation() {
331354
});
332355

333356
RewritePatternSet patterns(ctx);
357+
patterns.add<ConvertLayoutOpPattern>(ctx);
334358

335359
vector::UnrollVectorOptions vectorOptions;
336360
vectorOptions.setNativeShapeFn(options.nativeShape);

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 78 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -106,12 +106,12 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
106106
using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
107107

108108
// Calculate offset for each subgroup
109-
SmallVector<OpFoldResult>
109+
static SmallVector<OpFoldResult>
110110
calculateGlobalOffsets(ConversionPatternRewriter &rewriter, Location loc,
111111
const SmallVector<OpFoldResult> &originalOffsets,
112112
const SmallVector<Value> &localOffset,
113113
const SmallVector<int64_t> &distUnitBaseAddr,
114-
const SmallVector<int64_t> &distUnitShape) const {
114+
const SmallVector<int64_t> &distUnitShape) {
115115
assert(localOffset.size() == distUnitBaseAddr.size() &&
116116
"localOffset and distUnitBaseAddr must have the same rank");
117117

@@ -466,6 +466,75 @@ struct WgToSgElementwiseOp : public ConversionPattern {
466466
}
467467
};
468468

469+
// clang-format off
470+
// Pattern for lowering ConvertLayoutOp based on sg_layout and sg_data.
471+
// If input_layout and target_layout have identical sg_layout and sg_data,
472+
// the op is rewritten to a subgroup-level ConvertLayoutOp with these fields
473+
// dropped. For example:
474+
// #a = #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16], inst_data = [16, 16]>
475+
// #b = #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16], inst_data = [8, 16]>
476+
// xegpu.convert_layout %1 <{input_layout = #a, target_layout = #b}> : vector<32x64xf32>
477+
// becomes:
478+
// #a = #xegpu.layout<inst_data = [16, 16]>
479+
// #b = #xegpu.layout<inst_data = [8, 16]>
480+
// xegpu.convert_layout %1 <{input_layout = #a, target_layout = #b}> : vector<16x16xf32>
481+
// (vector<16x16xf32> is determined by sg_data = [16, 16])
482+
//
483+
// If sg_layout or sg_data differ, SLM is used to redistribute data across subgroups.
484+
// For example:
485+
// #a = #xegpu.layout<sg_layout = [1, 4], sg_data = [32, 16], inst_data = [16, 16]>
486+
// #b = #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 32], inst_data = [8, 16]>
487+
// xegpu.convert_layout %1 <{input_layout = #a, target_layout = #b}> : vector<32x64xf32>
488+
// is lowered to:
489+
// #a = #xegpu.layout<inst_data = [16, 16]>
490+
// #b = #xegpu.layout<inst_data = [8, 16]>
491+
// store_matrix %1, %slm <{layout_input_0 = #a}> : vector<32x16>, matrix_desc<32x64xf32>
492+
// %d = load_matrix %slm <{layout_result_0 = #a}> : matrix_desc<32x64xf32> -> vector<16x32xf32>
493+
// xegpu.convert_layout %d <{input_layout = #a, target_layout = #b}> : vector<16x32xf32>
494+
// clang-format on
495+
struct WgToSgConvertLayoutOp
496+
: public OpConversionPattern<xegpu::ConvertLayoutOp> {
497+
using OpConversionPattern<xegpu::ConvertLayoutOp>::OpConversionPattern;
498+
LogicalResult
499+
matchAndRewrite(xegpu::ConvertLayoutOp op, OneToNOpAdaptor adaptor,
500+
ConversionPatternRewriter &rewriter) const override {
501+
xegpu::LayoutAttr input = op.getInputLayout();
502+
xegpu::LayoutAttr target = op.getTargetLayout();
503+
504+
if (!input || !target || !input.isWgLayout() || !target.isWgLayout())
505+
return rewriter.notifyMatchFailure(
506+
op, "Input and target layouts must have subgroup layout");
507+
508+
DenseI32ArrayAttr inputSgLayout = input.getSgLayout();
509+
DenseI32ArrayAttr inputSgData = input.getSgData();
510+
DenseI32ArrayAttr inputOrder = input.getOrder();
511+
DenseI32ArrayAttr targetSgLayout = target.getSgLayout();
512+
DenseI32ArrayAttr targetSgData = target.getSgData();
513+
DenseI32ArrayAttr targetOrder = target.getOrder();
514+
515+
// TODO: currently we only support for optimal case, where input and
516+
// output has the same sg_layout and sg_data, so SLM is not involved.
517+
if (inputSgLayout != targetSgLayout || inputSgData != targetSgData ||
518+
inputOrder != targetOrder)
519+
return failure();
520+
521+
input = input.dropSgLayoutAndData();
522+
target = target.dropSgLayoutAndData();
523+
524+
SmallVector<Value> newOps(adaptor.getSource());
525+
if (input && target) {
526+
// keep the ConvertLayoutOp for rest fields, e.g., inst_data.
527+
for (auto [i, src] : llvm::enumerate(adaptor.getSource())) {
528+
auto newOp = rewriter.create<xegpu::ConvertLayoutOp>(
529+
op.getLoc(), src.getType(), src, input, target);
530+
newOps[i] = newOp;
531+
}
532+
}
533+
rewriter.replaceOpWithMultiple(op, {newOps});
534+
return success();
535+
}
536+
};
537+
469538
// Handles UnrealizedConversionCastOp generated during
470539
// SCFStructuralTypeConversions (step 1). This op may appear as either a
471540
// target or source materialization for Vector values, e.g.:
@@ -550,7 +619,8 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
550619
patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
551620
WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
552621
UnrealizedConversionCastOpPattern, WgToSgElementwiseOp,
553-
WgToSgVectorBroadcastOp>(patterns.getContext());
622+
WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp>(
623+
patterns.getContext());
554624
}
555625
} // namespace xegpu
556626
} // namespace mlir
@@ -662,6 +732,11 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
662732
return isLegal(xegpu::getLayoutAttr(op.getResult()));
663733
});
664734

735+
target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>(
736+
[=](xegpu::ConvertLayoutOp op) -> bool {
737+
return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout());
738+
});
739+
665740
target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
666741
[=](Operation *op) -> std::optional<bool> {
667742
// Only handle elementwise mappable ops

mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,10 @@ xegpu::LayoutAttr xegpu::getLayoutAttr(const Value value) {
124124
Operation *defOp = result.getDefiningOp();
125125
assert(defOp && "result must have a defining op");
126126

127+
// For ConvertLayoutOp, the layout is stored in the targetLayoutAttr
128+
if (auto convertOp = dyn_cast<xegpu::ConvertLayoutOp>(defOp))
129+
return convertOp.getTargetLayoutAttr();
130+
127131
// for LoadNdOp, the layout is stored in the tensor descriptor
128132
if (auto loadNd = dyn_cast<xegpu::LoadNdOp>(defOp))
129133
return getLayoutAttr(loadNd.getTensorDesc());
@@ -137,7 +141,8 @@ xegpu::LayoutAttr xegpu::getLayoutAttr(const Value value) {
137141
auto parentOp = arg.getOwner()->getParentOp();
138142
if (auto loop = dyn_cast<LoopLikeOpInterface>(parentOp)) {
139143
OpOperand *tiedInit = loop.getTiedLoopInit(arg);
140-
return getLayoutAttr(tiedInit->get());
144+
if (tiedInit)
145+
return getLayoutAttr(tiedInit->get());
141146
}
142147
}
143148

mlir/test/Dialect/XeGPU/invalid.mlir

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -548,19 +548,11 @@ func.func @tensor_desc_scatter_invalid_chunk_size_2D(%src: ui64, %offsets: vecto
548548
return
549549
}
550550

551-
// -----
552-
func.func @convert_layout_same_map(%a: vector<32x64xf16>) {
553-
// expected-error@+1 {{expected different srcMap and resMap}}
554-
%2 = xegpu.convert_layout %a {srcMap = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
555-
resMap = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<32x64xf16>
556-
gpu.return
557-
}
558-
559551
// -----
560552
func.func @convert_layout_unmatch(%a: vector<32x64xf16>) {
561-
// expected-error@+1 {{expected srcMap and resMap be WgLayout or SgLayout at the same time}}
562-
%2 = xegpu.convert_layout %a {srcMap = #xegpu.layout<sg_layout = [2, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>,
563-
resMap = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<32x64xf16>
553+
// expected-error@+1 {{expected input layout and target layout be WgLayout or SgLayout at the same time}}
554+
%2 = xegpu.convert_layout %a <{input_layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>,
555+
target_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> : vector<32x64xf16>
564556
gpu.return
565557
}
566558

mlir/test/Dialect/XeGPU/layout.mlir

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,18 @@ gpu.func @create_nd_tdesc_wg_1(%src: memref<24x32xf32>) {
3535
}
3636

3737
gpu.func @convert_layout(%a: vector<32x64xf16>) {
38-
%2 = xegpu.convert_layout %a {srcMap = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>,
39-
resMap = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<32x64xf16>
38+
// CHECK: xegpu.convert_layout
39+
// CHECK-SAME: <{input_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>, target_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> : vector<32x64xf16>
40+
%2 = xegpu.convert_layout %a <{input_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>,
41+
target_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> : vector<32x64xf16>
4042
gpu.return
4143
}
4244

4345
gpu.func @convert_layout_wg(%a: vector<32x64xf16>) {
44-
%2 = xegpu.convert_layout %a {srcMap = #xegpu.layout<sg_layout = [2, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>,
45-
resMap = #xegpu.layout<sg_layout = [4, 2], sg_data = [8, 32], lane_layout = [1, 16], lane_data = [1, 1]>} : vector<32x64xf16>
46+
// CHECK: xegpu.convert_layout
47+
// CHECK-SAME: <{input_layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>, target_layout = #xegpu.layout<sg_layout = [4, 2], sg_data = [8, 32], lane_layout = [1, 16], lane_data = [1, 1]>}> : vector<32x64xf16>
48+
%2 = xegpu.convert_layout %a <{input_layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>,
49+
target_layout = #xegpu.layout<sg_layout = [4, 2], sg_data = [8, 32], lane_layout = [1, 16], lane_data = [1, 1]>}> : vector<32x64xf16>
4650
gpu.return
4751
}
4852

0 commit comments

Comments
 (0)