Skip to content

Commit e9fc393

Browse files
authored
[MLIR][XeGPU][TransformOps] Add slice_dims argument to set_op_layout_attr and set_desc_layout (#168929)
The `set_op_layout_attr` and `set_desc_layout` transform ops now wrap the `xegpu.layout` attribute in an `xegpu.slice` attribute if the `slice_dims` argument is set.
1 parent 4c81b92 commit e9fc393

File tree

5 files changed

+139
-15
lines changed

5 files changed

+139
-15
lines changed

mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,12 @@ def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [
4242

4343
let summary = "Set xegpu.layout attribute to a xegpu.create_nd_desc op result.";
4444
let description = [{
45-
Given an `xegpu.create_nd_desc` operation, this transform adds `xegpu.layout`
46-
attribute to the result tensor descriptor. The layout is defined by the
47-
`sg_layout`, and `sg_data` and optional `inst_data` attributes. Returns a handle
48-
to the transformed op.
45+
Given an `xegpu.create_nd_tdesc` operation, this transform adds
46+
`xegpu.layout` attribute to the result tensor descriptor. The layout is
47+
defined by the `sg_layout`, `sg_data` and optional `inst_data`
48+
attributes. If `slice_dims` is provided, the `xegpu.layout` attribute is
49+
wrapped in an `xegpu.slice<..., dims=slice_dims>` attribute. Returns a handle to
50+
the transformed op.
4951
}];
5052

5153
let arguments = (ins
@@ -55,15 +57,17 @@ def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [
5557
Variadic<TransformAnyParamTypeOrAnyHandle>:$inst_data,
5658
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_layout,
5759
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_data,
58-
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_inst_data
60+
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_inst_data,
61+
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$slice_dims
5962
);
6063

6164
let results = (outs TransformHandleTypeInterface:$transformed);
6265
let builders = [
6366
OpBuilder<(ins "Value":$target,
6467
"ArrayRef<OpFoldResult>":$mixedSgLayout,
6568
"ArrayRef<OpFoldResult>":$mixedSgData,
66-
"ArrayRef<OpFoldResult>":$mixedInstData
69+
"ArrayRef<OpFoldResult>":$mixedInstData,
70+
"ArrayRef<int64_t>":$sliceDims
6771
)>,
6872
];
6973

@@ -72,6 +76,7 @@ def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [
7276
`sg_layout` `=` custom<DynamicIndexList>($sg_layout, $static_sg_layout)
7377
`sg_data` `=` custom<DynamicIndexList>($sg_data, $static_sg_data)
7478
(`inst_data` `=` custom<DynamicIndexList>($inst_data, $static_inst_data)^)?
79+
(`slice_dims` `=` $slice_dims^)?
7580
attr-dict `:` functional-type(operands, results)
7681
}];
7782

@@ -107,7 +112,9 @@ def SetOpLayoutAttrOp : Op<Transform_Dialect, "xegpu.set_op_layout_attr", [
107112
Sets the `xegpu.layout` attribute of an op. If `result=true`, sets the
108113
`layout_result_{index}`, otherwise `layout_operand_{index}` attribute. The
109114
target operand/result value is defined by the `index` argument. The layout
110-
is defined by the `sg_layout`, `sg_data` and optional `inst_data` attributes.
115+
is defined by the `sg_layout`, `sg_data` and optional `inst_data`
116+
attributes. If `slice_dims` is provided, the `xegpu.layout` attribute is
117+
wrapped in an `xegpu.slice<..., dims=slice_dims>` attribute.
111118
}];
112119

113120
let arguments = (ins TransformHandleTypeInterface:$target,
@@ -118,6 +125,7 @@ def SetOpLayoutAttrOp : Op<Transform_Dialect, "xegpu.set_op_layout_attr", [
118125
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_layout,
119126
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_data,
120127
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_inst_data,
128+
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$slice_dims,
121129
DefaultValuedAttr<UnitAttr, "false">:$result
122130
);
123131

@@ -128,6 +136,7 @@ def SetOpLayoutAttrOp : Op<Transform_Dialect, "xegpu.set_op_layout_attr", [
128136
"ArrayRef<OpFoldResult>":$mixedSgLayout,
129137
"ArrayRef<OpFoldResult>":$mixedSgData,
130138
"ArrayRef<OpFoldResult>":$mixedInstData,
139+
"ArrayRef<int64_t>":$sliceDims,
131140
CArg<"bool", "false">:$result
132141
)>,
133142
];
@@ -137,6 +146,7 @@ def SetOpLayoutAttrOp : Op<Transform_Dialect, "xegpu.set_op_layout_attr", [
137146
`sg_layout` `=` custom<DynamicIndexList>($sg_layout, $static_sg_layout)
138147
`sg_data` `=` custom<DynamicIndexList>($sg_data, $static_sg_data)
139148
(`inst_data` `=` custom<DynamicIndexList>($inst_data, $static_inst_data)^)?
149+
(`slice_dims` `=` $slice_dims^)?
140150
attr-dict `:` qualified(type(operands))
141151
}];
142152

mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,8 @@ getLayoutAttrFromOperands(MLIRContext *ctx, transform::TransformState &state,
167167
/// Replace xegpu.create_nd_desc op with a new one with the given layout.
168168
static xegpu::CreateNdDescOp
169169
setDescLayout(transform::TransformRewriter &rewriter,
170-
xegpu::CreateNdDescOp descOp, xegpu::LayoutAttr layout) {
170+
xegpu::CreateNdDescOp descOp,
171+
xegpu::DistributeLayoutAttr layout) {
171172
assert(descOp.getMixedOffsets().size() == 0 &&
172173
"create desc op with offsets is not supported");
173174
auto oldTensorDesc = descOp.getType();
@@ -212,7 +213,8 @@ void transform::SetDescLayoutOp::build(OpBuilder &builder,
212213
OperationState &result, Value target,
213214
ArrayRef<OpFoldResult> mixedSgLayout,
214215
ArrayRef<OpFoldResult> mixedSgData,
215-
ArrayRef<OpFoldResult> mixedInstData) {
216+
ArrayRef<OpFoldResult> mixedInstData,
217+
ArrayRef<int64_t> sliceDims) {
216218
SmallVector<int64_t> staticSgLayout, staticSgData, staticInstData;
217219
SmallVector<Value> dynamicSgLayout, dynamicSgData, dynamicInstData;
218220
dispatchIndexOpFoldResults(mixedSgLayout, dynamicSgLayout, staticSgLayout);
@@ -225,7 +227,8 @@ void transform::SetDescLayoutOp::build(OpBuilder &builder,
225227
/*inst_data=*/dynamicInstData,
226228
/*static_sg_layout=*/staticSgLayout,
227229
/*static_sg_data=*/staticSgData,
228-
/*static_inst_data=*/staticInstData);
230+
/*static_inst_data=*/staticInstData,
231+
/*slice_dims=*/sliceDims);
229232
}
230233

231234
DiagnosedSilenceableFailure
@@ -246,6 +249,14 @@ transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter,
246249
if (!status.succeeded())
247250
return status;
248251

252+
xegpu::DistributeLayoutAttr layout = layoutAttr;
253+
auto sliceDims = getSliceDims();
254+
if (sliceDims.size() > 0) {
255+
// Wrap layoutAttr in a slice attribute.
256+
layout = xegpu::SliceAttr::get(
257+
getContext(), layout, DenseI64ArrayAttr::get(getContext(), sliceDims));
258+
}
259+
249260
// For now only create_nd_desc op is supported.
250261
auto descOp = dyn_cast<xegpu::CreateNdDescOp>(target);
251262
if (!descOp) {
@@ -257,7 +268,7 @@ transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter,
257268
}
258269

259270
// Set layout attr in desc op's return type. Replaces old desc op.
260-
auto newdescOp = setDescLayout(rewriter, descOp, layoutAttr);
271+
auto newdescOp = setDescLayout(rewriter, descOp, layout);
261272

262273
// Map result handles.
263274
results.set(cast<OpResult>(getTransformed()), {newdescOp.getOperation()});
@@ -278,7 +289,8 @@ void transform::SetDescLayoutOp::getEffects(
278289
void transform::SetOpLayoutAttrOp::build(
279290
OpBuilder &builder, OperationState &ostate, Value target, int64_t index,
280291
ArrayRef<OpFoldResult> mixedSgLayout, ArrayRef<OpFoldResult> mixedSgData,
281-
ArrayRef<OpFoldResult> mixedInstData, bool result) {
292+
ArrayRef<OpFoldResult> mixedInstData, ArrayRef<int64_t> sliceDims,
293+
bool result) {
282294
SmallVector<int64_t> staticSgLayout, staticSgData, staticInstData;
283295
SmallVector<Value> dynamicSgLayout, dynamicSgData, dynamicInstData;
284296
dispatchIndexOpFoldResults(mixedSgLayout, dynamicSgLayout, staticSgLayout);
@@ -293,6 +305,7 @@ void transform::SetOpLayoutAttrOp::build(
293305
/*static_sg_layout=*/staticSgLayout,
294306
/*static_sg_data=*/staticSgData,
295307
/*static_inst_data=*/staticInstData,
308+
/*slice_dims=*/sliceDims,
296309
/*result=*/result);
297310
}
298311

@@ -326,11 +339,19 @@ transform::SetOpLayoutAttrOp::apply(transform::TransformRewriter &rewriter,
326339
if (!status.succeeded())
327340
return status;
328341

342+
xegpu::DistributeLayoutAttr layout = layoutAttr;
343+
auto sliceDims = getSliceDims();
344+
if (sliceDims.size() > 0) {
345+
// Wrap layoutAttr in a slice attribute.
346+
layout = xegpu::SliceAttr::get(
347+
getContext(), layout, DenseI64ArrayAttr::get(getContext(), sliceDims));
348+
}
349+
329350
// Set layout attribute for the op result or operand
330351
if (resultTarget)
331-
xegpu::setDistributeLayoutAttr(target->getResult(index), layoutAttr);
352+
xegpu::setDistributeLayoutAttr(target->getResult(index), layout);
332353
else
333-
xegpu::setDistributeLayoutAttr(target->getOpOperand(index), layoutAttr);
354+
xegpu::setDistributeLayoutAttr(target->getOpOperand(index), layout);
334355
return DiagnosedSilenceableFailure::success();
335356
}
336357

mlir/python/mlir/dialects/transform/xegpu.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ def __init__(
6262
sg_data: MixedValues,
6363
*,
6464
inst_data: Optional[MixedValues] = None,
65+
slice_dims: Optional[MixedInt] = None,
6566
loc=None,
6667
ip=None,
6768
):
@@ -92,6 +93,7 @@ def __init__(
9293
static_sg_layout=static_sg_layout,
9394
static_sg_data=static_sg_data,
9495
static_inst_data=static_inst_data,
96+
slice_dims=slice_dims,
9597
loc=loc,
9698
ip=ip,
9799
)
@@ -103,6 +105,7 @@ def set_desc_layout(
103105
sg_data: MixedValues,
104106
*,
105107
inst_data: Optional[MixedValues] = None,
108+
slice_dims: Optional[MixedInt] = None,
106109
loc=None,
107110
ip=None,
108111
) -> OpResult:
@@ -111,6 +114,7 @@ def set_desc_layout(
111114
sg_layout,
112115
sg_data,
113116
inst_data=inst_data,
117+
slice_dims=slice_dims,
114118
loc=loc,
115119
ip=ip,
116120
).result
@@ -127,6 +131,7 @@ def __init__(
127131
sg_data: MixedValues,
128132
*,
129133
inst_data: Optional[MixedValues] = None,
134+
slice_dims: Optional[MixedInt] = None,
130135
index: Optional[Union[int, Attribute]] = None,
131136
result: Optional[Union[bool, Attribute]] = None,
132137
loc=None,
@@ -156,6 +161,7 @@ def __init__(
156161
static_sg_layout=static_sg_layout,
157162
static_sg_data=static_sg_data,
158163
static_inst_data=static_inst_data,
164+
slice_dims=slice_dims,
159165
index=index,
160166
result=result,
161167
loc=loc,
@@ -169,6 +175,7 @@ def set_op_layout_attr(
169175
sg_data: MixedValues,
170176
*,
171177
inst_data: Optional[MixedValues] = None,
178+
slice_dims: Optional[MixedInt] = None,
172179
index: Optional[Union[int, Attribute]] = None,
173180
result: Optional[Union[bool, Attribute]] = None,
174181
loc=None,
@@ -179,6 +186,7 @@ def set_op_layout_attr(
179186
sg_layout,
180187
sg_data,
181188
inst_data=inst_data,
189+
slice_dims=slice_dims,
182190
index=index,
183191
result=result,
184192
loc=loc,

mlir/test/Dialect/XeGPU/transform-ops.mlir

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,25 @@ module attributes {transform.with_named_sequence} {
121121

122122
// -----
123123

124+
// CHECK-LABEL: @set_desc_layout_slice
125+
func.func @set_desc_layout_slice(%arg0: memref<4096xf16>) {
126+
// CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
127+
// CHECK-SAME: #xegpu.slice<#xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>, dims = [0]>
128+
%0 = xegpu.create_nd_tdesc %arg0 : memref<4096xf16> -> !xegpu.tensor_desc<256xf16>
129+
return
130+
}
131+
132+
module attributes {transform.with_named_sequence} {
133+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
134+
%0 = transform.structured.match ops{["xegpu.create_nd_tdesc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
135+
// CHECK: transform.xegpu.set_desc_layout %{{.*}}
136+
%1 = transform.xegpu.set_desc_layout %0 sg_layout = [8, 4] sg_data = [32, 32] slice_dims = [0] : (!transform.any_op) -> !transform.any_op
137+
transform.yield
138+
}
139+
}
140+
141+
// -----
142+
124143
// CHECK-LABEL: @set_op_layout_attr_result_default_index
125144
func.func @set_op_layout_attr_result_default_index(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
126145
%0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
@@ -212,6 +231,25 @@ module attributes {transform.with_named_sequence} {
212231

213232
// -----
214233

234+
// CHECK-LABEL: @set_op_layout_attr_result_slice
235+
func.func @set_op_layout_attr_result_slice(%arg0: vector<256xf16>) {
236+
// CHECK: = arith.extf
237+
// CHECK-SAME: {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>, dims = [0]>}
238+
%2 = arith.extf %arg0 : vector<256xf16> to vector<256xf32>
239+
return
240+
}
241+
242+
module attributes {transform.with_named_sequence} {
243+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
244+
%0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
245+
// CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
246+
transform.xegpu.set_op_layout_attr %0 result index = 0 sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] slice_dims = [0] : !transform.any_op
247+
transform.yield
248+
}
249+
}
250+
251+
// -----
252+
215253
// CHECK-LABEL: @set_op_layout_attr_operand_minimal
216254
func.func @set_op_layout_attr_operand_minimal(%arg0: memref<4096x4096xf16>) {
217255
%0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>

mlir/test/python/dialects/transform_xegpu_ext.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,25 @@ def setDescLayoutInstData():
6666
# CHECK: inst_data = [8, 16]
6767

6868

69+
@run
70+
def setDescLayoutSlice():
71+
sequence = transform.SequenceOp(
72+
transform.FailurePropagationMode.Propagate,
73+
[],
74+
transform.OperationType.get("xegpu.create_nd_tdesc"),
75+
)
76+
with InsertionPoint(sequence.body):
77+
xegpu.set_desc_layout(
78+
sequence.bodyTarget, sg_layout=[6, 4], sg_data=[32, 16], slice_dims=[0]
79+
)
80+
transform.YieldOp()
81+
# CHECK-LABEL: TEST: setDescLayoutSlice
82+
# CHECK: %0 = transform.xegpu.set_desc_layout %
83+
# CHECK: sg_layout = [6, 4]
84+
# CHECK: sg_data = [32, 16]
85+
# CHECK: slice_dims = [0]
86+
87+
6988
@run
7089
def setOpLayoutAttrOperandMinimal():
7190
sequence = transform.SequenceOp(
@@ -106,13 +125,41 @@ def setOpLayoutAttrResult():
106125
result=True,
107126
)
108127
transform.YieldOp()
109-
# CHECK-LABEL: TEST: setOpLayoutAttr
128+
# CHECK-LABEL: TEST: setOpLayoutAttrResult
129+
# CHECK: transform.xegpu.set_op_layout_attr %
130+
# NO-CHECK: index = 0
131+
# CHECK: result
132+
# CHECK: sg_layout = [6, 4]
133+
# CHECK: sg_data = [32, 16]
134+
# CHECK: inst_data = [8, 16]
135+
136+
137+
@run
138+
def setOpLayoutAttrResultSlice():
139+
sequence = transform.SequenceOp(
140+
transform.FailurePropagationMode.Propagate,
141+
[],
142+
transform.OperationType.get("xegpu.dpas"),
143+
)
144+
with InsertionPoint(sequence.body):
145+
xegpu.set_op_layout_attr(
146+
sequence.bodyTarget,
147+
index=0,
148+
sg_layout=[6, 4],
149+
sg_data=[32, 16],
150+
inst_data=[8, 16],
151+
slice_dims=[0],
152+
result=True,
153+
)
154+
transform.YieldOp()
155+
# CHECK-LABEL: TEST: setOpLayoutAttrResultSlice
110156
# CHECK: transform.xegpu.set_op_layout_attr %
111157
# NO-CHECK: index = 0
112158
# CHECK: result
113159
# CHECK: sg_layout = [6, 4]
114160
# CHECK: sg_data = [32, 16]
115161
# CHECK: inst_data = [8, 16]
162+
# CHECK: slice_dims = [0]
116163

117164

118165
@run

0 commit comments

Comments
 (0)