Skip to content

Commit f3af2c3

Browse files
committed
update convert_layout
1 parent a84014f commit f3af2c3

File tree

4 files changed

+11
-7
lines changed

4 files changed

+11
-7
lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,9 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
217217
InterfaceMethod<"Derive a new layout by dropping sgLayout and sgData",
218218
"xegpu::DistributeLayoutAttr",
219219
"dropSgLayoutAndData">,
220+
InterfaceMethod<"Derive a new layout by dropping InstData",
221+
"xegpu::DistributeLayoutAttr",
222+
"dropInstData">,
220223
InterfaceMethod<[{Delinearizes a linear subgroup ID into its multidimensional
221224
indices based on the effective subgroup layout.}],
222225
"FailureOr<SmallVector<Value>>",

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1162,8 +1162,8 @@ def XeGPU_ConvertLayoutOp: XeGPU_Op<"convert_layout", [Pure, AllTypesMatch<["sou
11621162
the IR is lowered to WI level because that is the end result of all distributions.
11631163
}];
11641164
let arguments = (ins XeGPU_VectorType: $source,
1165-
XeGPU_LayoutAttr: $input_layout,
1166-
XeGPU_LayoutAttr: $target_layout);
1165+
DistributeLayoutAttr: $input_layout,
1166+
DistributeLayoutAttr: $target_layout);
11671167
let results = (outs XeGPU_VectorType: $result);
11681168
let assemblyFormat = [{
11691169
$source prop-dict attr-dict `:` type($source)

mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,9 @@ struct ConvertLayoutOpPattern
8484
using OpRewritePattern::OpRewritePattern;
8585
LogicalResult matchAndRewrite(xegpu::ConvertLayoutOp op,
8686
PatternRewriter &rewriter) const override {
87-
xegpu::LayoutAttr input_layout = op.getInputLayoutAttr();
88-
xegpu::LayoutAttr target_layout = op.getTargetLayoutAttr();
89-
if (!input_layout.getInstData() || !target_layout.getInstData())
87+
xegpu::DistributeLayoutAttr input_layout = op.getInputLayoutAttr();
88+
xegpu::DistributeLayoutAttr target_layout = op.getTargetLayoutAttr();
89+
if (!input_layout.getInstDataAsInt() || !target_layout.getInstDataAsInt())
9090
return rewriter.notifyMatchFailure(op, "Not a target ConvertLayoutOp.");
9191

9292
input_layout = input_layout.dropInstData();

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -613,8 +613,9 @@ struct WgToSgConvertLayoutOp
613613
LogicalResult
614614
matchAndRewrite(xegpu::ConvertLayoutOp op, OneToNOpAdaptor adaptor,
615615
ConversionPatternRewriter &rewriter) const override {
616-
xegpu::LayoutAttr input = op.getInputLayout();
617-
xegpu::LayoutAttr target = op.getTargetLayout();
616+
// TODO: currently, we only support LayoutAttr
617+
auto input = dyn_cast<xegpu::LayoutAttr>(op.getInputLayout());
618+
auto target = dyn_cast<xegpu::LayoutAttr>(op.getTargetLayout());
618619

619620
if (!input || !target || !input.isForWorkgroup() ||
620621
!target.isForWorkgroup())

0 commit comments

Comments
 (0)