Skip to content

Commit a038699

Browse files
committed
[mlir][xegpu][transformops] add set_op_layout_attr op
1 parent bda7289 commit a038699

File tree

6 files changed

+459
-19
lines changed

6 files changed

+459
-19
lines changed

mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,4 +78,69 @@ def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [
7878
}];
7979
}
8080

81+
def SetOpLayoutAttrOp : Op<Transform_Dialect, "xegpu.set_op_layout_attr", [
82+
AttrSizedOperandSegments,
83+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
84+
TransformOpInterface
85+
]> {
86+
87+
let summary = "Set xegpu.layout attribute of an op.";
88+
let description = [{
89+
Sets the `xegpu.layout` attribute of an op. If `result=true`, sets the
90+
`layout_result_{index}`, otherwise `layout_operand_{index}` attribute. The
91+
target operand/result value is defined by the `index` argument. The layout
92+
is defined by the `sg_layout`, `sg_data` and optional `inst_data` attributes.
93+
}];
94+
95+
let arguments = (ins TransformHandleTypeInterface : $target,
96+
DefaultValuedOptionalAttr<I64Attr, "0"> : $index,
97+
Variadic<TransformAnyParamTypeOrAnyHandle> : $sg_layout,
98+
Variadic<TransformAnyParamTypeOrAnyHandle> : $sg_data,
99+
Variadic<TransformAnyParamTypeOrAnyHandle> : $inst_data,
100+
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_layout,
101+
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_data,
102+
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_inst_data,
103+
DefaultValuedAttr<UnitAttr, "false">:$result
104+
);
105+
106+
let results = (outs);
107+
let builders = [
108+
OpBuilder<(ins "Value":$target,
109+
"int64_t":$index,
110+
"ArrayRef<OpFoldResult>":$mixedSgLayout,
111+
"ArrayRef<OpFoldResult>":$mixedSgData,
112+
"ArrayRef<OpFoldResult>":$mixedInstData,
113+
CArg<"bool", "false">:$result
114+
)>,
115+
];
116+
117+
let assemblyFormat = [{
118+
$target (`result` $result^)? (`index` `=` $index^)?
119+
`sg_layout` `=` custom<DynamicIndexList>($sg_layout, $static_sg_layout)
120+
`sg_data` `=` custom<DynamicIndexList>($sg_data, $static_sg_data)
121+
(`inst_data` `=` custom<DynamicIndexList>($inst_data, $static_inst_data)^)?
122+
attr-dict `:` qualified(type(operands))
123+
}];
124+
125+
let extraClassDeclaration = [{
126+
::mlir::DiagnosedSilenceableFailure apply(
127+
::mlir::transform::TransformRewriter &rewriter,
128+
::mlir::transform::TransformResults &transformResults,
129+
::mlir::transform::TransformState &state);
130+
131+
::llvm::SmallVector<::mlir::OpFoldResult> getMixedSgLayout() {
132+
Builder b(getContext());
133+
return getMixedValues(getStaticSgLayout(), getSgLayout(), b);
134+
}
135+
::llvm::SmallVector<::mlir::OpFoldResult> getMixedSgData() {
136+
Builder b(getContext());
137+
return getMixedValues(getStaticSgData(), getSgData(), b);
138+
}
139+
::llvm::SmallVector<::mlir::OpFoldResult> getMixedInstData() {
140+
Builder b(getContext());
141+
return getMixedValues(getStaticInstData(), getInstData(), b);
142+
}
143+
}];
144+
}
145+
81146
#endif // XEGPU_TRANSFORM_OPS

mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp

Lines changed: 106 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,38 @@ createLayoutAttr(MLIRContext *ctx, ArrayRef<int32_t> sgLayout,
9090
/*order=*/nullptr);
9191
}
9292

93+
/// Generate `xegpu::LayoutAttr` from op mixed layout values.
94+
DiagnosedSilenceableFailure
95+
getLayoutAttrFromOperands(transform::TransformRewriter &rewriter,
96+
transform::TransformState &state,
97+
TransformOpInterface transformOp,
98+
ArrayRef<::mlir::OpFoldResult> mixedSgLayout,
99+
ArrayRef<::mlir::OpFoldResult> mixedSgData,
100+
ArrayRef<::mlir::OpFoldResult> mixedInstData,
101+
xegpu::LayoutAttr &layoutAttr) {
102+
SmallVector<int32_t> sgLayout, sgData, instData;
103+
auto status =
104+
convertMixedValuesToInt(state, transformOp, sgLayout, mixedSgLayout);
105+
if (!status.succeeded())
106+
return status;
107+
108+
status = convertMixedValuesToInt(state, transformOp, sgData, mixedSgData);
109+
if (!status.succeeded())
110+
return status;
111+
112+
status = convertMixedValuesToInt(state, transformOp, instData, mixedInstData);
113+
if (!status.succeeded())
114+
return status;
115+
auto maybeInstData = instData.empty()
116+
? std::nullopt
117+
: std::optional<ArrayRef<int32_t>>(instData);
118+
119+
layoutAttr =
120+
createLayoutAttr(rewriter.getContext(), sgLayout, sgData, maybeInstData);
121+
122+
return DiagnosedSilenceableFailure::success();
123+
}
124+
93125
/// Replace xegpu.create_nd_desc op with a new one with the given layout.
94126
static xegpu::CreateNdDescOp
95127
setDescLayout(transform::TransformRewriter &rewriter,
@@ -142,26 +174,13 @@ transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter,
142174
}
143175
Operation *target = *targetOps.begin();
144176

145-
SmallVector<int32_t> sgLayout;
146-
DiagnosedSilenceableFailure status =
147-
convertMixedValuesToInt(state, (*this), sgLayout, getMixedSgLayout());
177+
xegpu::LayoutAttr layoutAttr = nullptr;
178+
auto status = getLayoutAttrFromOperands(rewriter, state, (*this),
179+
getMixedSgLayout(), getMixedSgData(),
180+
getMixedInstData(), layoutAttr);
148181
if (!status.succeeded())
149182
return status;
150183

151-
SmallVector<int32_t> sgData;
152-
status = convertMixedValuesToInt(state, (*this), sgData, getMixedSgData());
153-
if (!status.succeeded())
154-
return status;
155-
156-
SmallVector<int32_t> instData;
157-
status =
158-
convertMixedValuesToInt(state, (*this), instData, getMixedInstData());
159-
if (!status.succeeded())
160-
return status;
161-
auto maybeInstData = instData.empty()
162-
? std::nullopt
163-
: std::optional<ArrayRef<int32_t>>(instData);
164-
165184
// For now only create_nd_desc op is supported.
166185
auto descOp = dyn_cast<xegpu::CreateNdDescOp>(target);
167186
if (!descOp) {
@@ -173,8 +192,6 @@ transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter,
173192
}
174193

175194
// Set layout attr in desc op's return type. Replaces old desc op.
176-
auto layoutAttr =
177-
createLayoutAttr(rewriter.getContext(), sgLayout, sgData, maybeInstData);
178195
auto newdescOp = setDescLayout(rewriter, descOp, layoutAttr);
179196

180197
// Map result handles.
@@ -193,6 +210,76 @@ void transform::SetDescLayoutOp::getEffects(
193210
modifiesPayload(effects);
194211
}
195212

213+
void transform::SetOpLayoutAttrOp::build(
214+
OpBuilder &builder, OperationState &ostate, Value target, int64_t index,
215+
ArrayRef<OpFoldResult> mixedSgLayout, ArrayRef<OpFoldResult> mixedSgData,
216+
ArrayRef<OpFoldResult> mixedInstData, bool result) {
217+
SmallVector<int64_t> staticSgLayout, staticSgData, staticInstData;
218+
SmallVector<Value> dynamicSgLayout, dynamicSgData, dynamicInstData;
219+
dispatchIndexOpFoldResults(mixedSgLayout, dynamicSgLayout, staticSgLayout);
220+
dispatchIndexOpFoldResults(mixedSgData, dynamicSgData, staticSgData);
221+
dispatchIndexOpFoldResults(mixedInstData, dynamicInstData, staticInstData);
222+
build(builder, ostate, target.getType(),
223+
/*target=*/target,
224+
/*index=*/index,
225+
/*sg_layout=*/dynamicSgLayout,
226+
/*sg_data=*/dynamicSgData,
227+
/*inst_data=*/dynamicInstData,
228+
/*static_sg_layout=*/staticSgLayout,
229+
/*static_sg_data=*/staticSgData,
230+
/*static_inst_data=*/staticInstData,
231+
/*result=*/result);
232+
}
233+
234+
DiagnosedSilenceableFailure
235+
transform::SetOpLayoutAttrOp::apply(transform::TransformRewriter &rewriter,
236+
transform::TransformResults &results,
237+
transform::TransformState &state) {
238+
239+
auto targetOps = state.getPayloadOps(getTarget());
240+
if (!llvm::hasSingleElement(targetOps)) {
241+
return emitDefiniteFailure() << "Requires exactly one targetOp handle (got "
242+
<< llvm::range_size(targetOps) << ")";
243+
}
244+
Operation *target = *targetOps.begin();
245+
246+
bool resultTarget = getResult();
247+
248+
int64_t index = getIndex();
249+
if (resultTarget && index >= target->getNumResults()) {
250+
return emitSilenceableFailure(getLoc())
251+
<< "Index exceeds the number of op results";
252+
}
253+
if (!resultTarget && index >= target->getNumOperands()) {
254+
return emitSilenceableFailure(getLoc())
255+
<< "Index exceeds the number of op operands";
256+
}
257+
258+
xegpu::LayoutAttr layoutAttr = nullptr;
259+
auto status = getLayoutAttrFromOperands(rewriter, state, (*this),
260+
getMixedSgLayout(), getMixedSgData(),
261+
getMixedInstData(), layoutAttr);
262+
if (!status.succeeded())
263+
return status;
264+
265+
// Set layout attribute for the op result or operand
266+
if (resultTarget) {
267+
xegpu::setDistributeLayoutAttr(target->getResult(index), layoutAttr);
268+
} else {
269+
xegpu::setDistributeLayoutAttr(target->getOpOperand(index), layoutAttr);
270+
}
271+
return DiagnosedSilenceableFailure::success();
272+
}
273+
274+
void transform::SetOpLayoutAttrOp::getEffects(
275+
::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
276+
onlyReadsHandle(getTargetMutable(), effects);
277+
onlyReadsHandle(getSgLayoutMutable(), effects);
278+
onlyReadsHandle(getSgDataMutable(), effects);
279+
onlyReadsHandle(getInstDataMutable(), effects);
280+
modifiesPayload(effects);
281+
}
282+
196283
namespace {
197284
class XeGPUTransformDialectExtension
198285
: public transform::TransformDialectExtension<

mlir/python/mlir/dialects/transform/xegpu.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,3 +64,50 @@ def __init__(
6464
loc=loc,
6565
ip=ip,
6666
)
67+
68+
69+
@_ods_cext.register_operation(_Dialect, replace=True)
70+
class SetOpLayoutAttrOp(SetOpLayoutAttrOp):
71+
"""Specialization for SetOpLayoutAttrOp class."""
72+
73+
def __init__(
74+
self,
75+
target: Union[Operation, Value],
76+
sg_layout: MixedValues,
77+
sg_data: MixedValues,
78+
*,
79+
inst_data: MixedValues = None,
80+
index: Union[int, Attribute] = None,
81+
result: Union[bool, Attribute] = None,
82+
loc=None,
83+
ip=None,
84+
):
85+
inst_data = [] if inst_data is None else inst_data
86+
(
87+
dynamic_sg_layout,
88+
static_sg_layout,
89+
_,
90+
) = _dispatch_dynamic_index_list(sg_layout)
91+
(
92+
dynamic_sg_data,
93+
static_sg_data,
94+
_,
95+
) = _dispatch_dynamic_index_list(sg_data)
96+
(
97+
dynamic_inst_data,
98+
static_inst_data,
99+
_,
100+
) = _dispatch_dynamic_index_list(inst_data)
101+
super().__init__(
102+
_get_op_result_or_value(target),
103+
dynamic_sg_layout,
104+
dynamic_sg_data,
105+
dynamic_inst_data,
106+
static_sg_layout=static_sg_layout,
107+
static_sg_data=static_sg_data,
108+
static_inst_data=static_inst_data,
109+
index=index,
110+
result=result,
111+
loc=loc,
112+
ip=ip,
113+
)

mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,61 @@ module attributes {transform.with_named_sequence} {
1313
transform.yield
1414
}
1515
}
16+
17+
// -----
18+
19+
// CHECK-LABEL: @set_op_layout_attr_bad_result_index
20+
func.func @set_op_layout_attr_bad_result_index(%arg0: memref<4096x4096xf16>) {
21+
%0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
22+
%1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
23+
%2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32>
24+
return
25+
}
26+
27+
module attributes {transform.with_named_sequence} {
28+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
29+
%0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
30+
// expected-error@below {{Index exceeds the number of op results}}
31+
transform.xegpu.set_op_layout_attr %0 result index = 1 sg_layout = [8, 4] sg_data = [32, 64] : !transform.any_op
32+
transform.yield
33+
}
34+
}
35+
36+
// -----
37+
38+
// CHECK-LABEL: @set_op_layout_attr_bad_operand_index
39+
func.func @set_op_layout_attr_bad_operand_index(%arg0: memref<4096x4096xf16>) {
40+
%0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
41+
%1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
42+
%2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32>
43+
return
44+
}
45+
46+
module attributes {transform.with_named_sequence} {
47+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
48+
%0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
49+
// expected-error@below {{Index exceeds the number of op operands}}
50+
transform.xegpu.set_op_layout_attr %0 index = 1 sg_layout = [8, 4] sg_data = [32, 64] : !transform.any_op
51+
transform.yield
52+
}
53+
}
54+
55+
// -----
56+
57+
// CHECK-LABEL: @set_op_layout_attr_multiple
58+
func.func @set_op_layout_attr_multiple(%arg0: memref<4096x4096xf16>) {
59+
%0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
60+
%1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
61+
%2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32>
62+
%3 = arith.extf %2 : vector<256x32xf32> to vector<256x32xf64>
63+
return
64+
}
65+
66+
module attributes {transform.with_named_sequence} {
67+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
68+
%0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
69+
// expected-error@below {{Requires exactly one targetOp handle (got 2)}}
70+
transform.xegpu.set_op_layout_attr %0 sg_layout = [8, 4] sg_data = [32, 64] : !transform.any_op
71+
transform.yield
72+
}
73+
}

0 commit comments

Comments
 (0)