Commit 3fad504

[mlir][xegpu][transformops] add convert_layout op
1 parent bdcd591 commit 3fad504

5 files changed: +292 -0 lines changed

mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td

Lines changed: 62 additions & 0 deletions
@@ -200,4 +200,66 @@ def SetGPULaunchThreadsOp
  }];
}

def ConvertLayoutOp : Op<Transform_Dialect, "xegpu.convert_layout", [
    AttrSizedOperandSegments,
    DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
    TransformOpInterface
  ]> {

  let summary = "Convert the xegpu.layout attribute of a value.";
  let description = [{
    Adds an `xegpu.convert_layout` op to convert the `xegpu.layout` attribute
    of a value. First, the `xegpu.load_nd` producer op of the value is found;
    it must already be annotated with a layout. An `xegpu.convert_layout` op,
    whose target layout is defined by the `sg_layout`, `sg_data`, and
    optional `inst_data` attributes, is then inserted after the load op.
  }];

  let arguments = (ins TransformValueHandleTypeInterface:$target,
                   Variadic<TransformAnyParamTypeOrAnyHandle>:$sg_layout,
                   Variadic<TransformAnyParamTypeOrAnyHandle>:$sg_data,
                   Variadic<TransformAnyParamTypeOrAnyHandle>:$inst_data,
                   DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_layout,
                   DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_data,
                   DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_inst_data
  );

  let results = (outs);
  let builders = [
    OpBuilder<(ins "Value":$target,
                   "ArrayRef<OpFoldResult>":$mixedSgLayout,
                   "ArrayRef<OpFoldResult>":$mixedSgData,
                   "ArrayRef<OpFoldResult>":$mixedInstData
    )>,
  ];

  let assemblyFormat = [{
    $target
    `sg_layout` `=` custom<DynamicIndexList>($sg_layout, $static_sg_layout)
    `sg_data` `=` custom<DynamicIndexList>($sg_data, $static_sg_data)
    (`inst_data` `=` custom<DynamicIndexList>($inst_data, $static_inst_data)^)?
    attr-dict `:` qualified(type(operands))
  }];

  let extraClassDeclaration = [{
    ::mlir::DiagnosedSilenceableFailure apply(
        ::mlir::transform::TransformRewriter &rewriter,
        ::mlir::transform::TransformResults &transformResults,
        ::mlir::transform::TransformState &state);

    ::llvm::SmallVector<::mlir::OpFoldResult> getMixedSgLayout() {
      Builder b(getContext());
      return getMixedValues(getStaticSgLayout(), getSgLayout(), b);
    }
    ::llvm::SmallVector<::mlir::OpFoldResult> getMixedSgData() {
      Builder b(getContext());
      return getMixedValues(getStaticSgData(), getSgData(), b);
    }
    ::llvm::SmallVector<::mlir::OpFoldResult> getMixedInstData() {
      Builder b(getContext());
      return getMixedValues(getStaticInstData(), getInstData(), b);
    }
  }];
}

#endif // XEGPU_TRANSFORM_OPS
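
For orientation, the new op is used from a transform script roughly as follows (a minimal sketch that mirrors the test added in transform-ops.mlir below; the choice of matching `xegpu.dpas` and taking operand 0 is just illustrative):

  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
      %dpas = transform.structured.match ops{["xegpu.dpas"]} in %arg0 : (!transform.any_op) -> !transform.any_op
      // Take a handle to the dpas A operand and request a new layout for it.
      %a = transform.get_operand %dpas[0] : (!transform.any_op) -> !transform.any_value
      transform.xegpu.convert_layout %a sg_layout = [8, 4] sg_data = [32, 32] inst_data = [8, 16] : !transform.any_value
      transform.yield
    }
  }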

mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp

Lines changed: 79 additions & 0 deletions
@@ -405,6 +405,85 @@ void transform::SetGPULaunchThreadsOp::getEffects(
  modifiesPayload(effects);
}

void transform::ConvertLayoutOp::build(OpBuilder &builder,
                                       OperationState &ostate, Value target,
                                       ArrayRef<OpFoldResult> mixedSgLayout,
                                       ArrayRef<OpFoldResult> mixedSgData,
                                       ArrayRef<OpFoldResult> mixedInstData) {
  SmallVector<int64_t> staticSgLayout, staticSgData, staticInstData;
  SmallVector<Value> dynamicSgLayout, dynamicSgData, dynamicInstData;
  dispatchIndexOpFoldResults(mixedSgLayout, dynamicSgLayout, staticSgLayout);
  dispatchIndexOpFoldResults(mixedSgData, dynamicSgData, staticSgData);
  dispatchIndexOpFoldResults(mixedInstData, dynamicInstData, staticInstData);
  build(builder, ostate, target.getType(),
        /*target=*/target,
        /*sg_layout=*/dynamicSgLayout,
        /*sg_data=*/dynamicSgData,
        /*inst_data=*/dynamicInstData,
        /*static_sg_layout=*/staticSgLayout,
        /*static_sg_data=*/staticSgData,
        /*static_inst_data=*/staticInstData);
}

DiagnosedSilenceableFailure
transform::ConvertLayoutOp::apply(transform::TransformRewriter &rewriter,
                                  transform::TransformResults &results,
                                  transform::TransformState &state) {
  auto targetValues = state.getPayloadValues(getTarget());
  if (!llvm::hasSingleElement(targetValues)) {
    return emitDefiniteFailure()
           << "requires exactly one target value handle (got "
           << llvm::range_size(targetValues) << ")";
  }

  auto value = *targetValues.begin();

  // Build the requested (target) layout from the op's mixed static/dynamic
  // operands.
  xegpu::LayoutAttr layoutAttr = nullptr;
  auto status = getLayoutAttrFromOperands(getContext(), state, (*this),
                                          getMixedSgLayout(), getMixedSgData(),
                                          getMixedInstData(), layoutAttr);
  if (!status.succeeded())
    return status;

  // Find the xegpu.load_nd op producing the target value.
  auto maybeLoadOp = findProducerOfType<xegpu::LoadNdOp>(value);
  if (!maybeLoadOp) {
    return emitSilenceableFailure(getLoc()) << "Could not find load op.";
  }
  auto loadOp = *maybeLoadOp;
  // Get the layout of the load op's source operand.
  auto producerLayoutAttr =
      xegpu::getDistributeLayoutAttr(loadOp.getOperand(0));
  if (!producerLayoutAttr) {
    return emitSilenceableFailure(getLoc())
           << "Operand producer op does not have a layout attr.";
  }

  if (producerLayoutAttr != layoutAttr) {
    rewriter.setInsertionPointAfter(loadOp.getOperation());
    auto source = loadOp.getResult();
    auto convLayoutOp = xegpu::ConvertLayoutOp::create(
        rewriter, loadOp.getLoc(), source.getType(), source, producerLayoutAttr,
        layoutAttr);
    // Redirect all uses of the load result (except the convert_layout op
    // itself) to the converted value.
    rewriter.replaceUsesWithIf(
        source, convLayoutOp.getResult(), [&](OpOperand &use) {
          return use.getOwner() != convLayoutOp.getOperation();
        });
  }

  return DiagnosedSilenceableFailure::success();
}

void transform::ConvertLayoutOp::getEffects(
    ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  onlyReadsHandle(getTargetMutable(), effects);
  onlyReadsHandle(getSgLayoutMutable(), effects);
  onlyReadsHandle(getSgDataMutable(), effects);
  onlyReadsHandle(getInstDataMutable(), effects);
  modifiesPayload(effects);
}

namespace {
class XeGPUTransformDialectExtension
    : public transform::TransformDialectExtension<
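
In payload terms, when the requested layout differs from the layout already attached to the `xegpu.load_nd` producer, `apply()` performs roughly the following rewrite (a sketch distilled from the test expectations below; types and the `xegpu.convert_layout` assembly syntax are abbreviated, and the layout values are illustrative):

  // Before: the load result %1 (inst_data = [32, 16]) feeds the dpas directly.
  %1 = xegpu.load_nd %0[%c0, %c0] : !xegpu.tensor_desc<256x32xf16, #xegpu.layout<...>> -> vector<256x32xf16>
  %6 = xegpu.dpas %1, %3, %5 : ...

  // After: a convert_layout is inserted right after the load, and every other
  // use of %1 (here the dpas) is redirected to its result.
  %1 = xegpu.load_nd %0[%c0, %c0] : !xegpu.tensor_desc<256x32xf16, #xegpu.layout<...>> -> vector<256x32xf16>
  %2 = xegpu.convert_layout %1 ...   // input_layout  = <sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16]>
                                     // target_layout = <sg_layout = [8, 4], sg_data = [32, 32], inst_data = [8, 16]>
  %6 = xegpu.dpas %2, %3, %5 : ...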

mlir/python/mlir/dialects/transform/xegpu.py

Lines changed: 44 additions & 0 deletions
@@ -134,6 +134,7 @@ def __init__(
        )


@_ods_cext.register_operation(_Dialect, replace=True)
class SetGPULaunchThreadsOp(SetGPULaunchThreadsOp):
    """Specialization for SetGPULaunchThreadsOp class."""

@@ -168,3 +169,46 @@ def set_gpu_launch_threads(
        ip=None,
    ) -> SetGPULaunchThreadsOp:
    return SetGPULaunchThreadsOp(launch_op, threads, loc=loc, ip=ip)


@_ods_cext.register_operation(_Dialect, replace=True)
class ConvertLayoutOp(ConvertLayoutOp):
    """Specialization for ConvertLayoutOp class."""

    def __init__(
        self,
        target: Value,
        sg_layout: MixedValues,
        sg_data: MixedValues,
        *,
        inst_data: Optional[MixedValues] = None,
        loc=None,
        ip=None,
    ):
        inst_data = [] if inst_data is None else inst_data
        (
            dynamic_sg_layout,
            static_sg_layout,
            _,
        ) = _dispatch_dynamic_index_list(sg_layout)
        (
            dynamic_sg_data,
            static_sg_data,
            _,
        ) = _dispatch_dynamic_index_list(sg_data)
        (
            dynamic_inst_data,
            static_inst_data,
            _,
        ) = _dispatch_dynamic_index_list(inst_data)
        super().__init__(
            target,
            dynamic_sg_layout,
            dynamic_sg_data,
            dynamic_inst_data,
            static_sg_layout=static_sg_layout,
            static_sg_data=static_sg_data,
            static_inst_data=static_inst_data,
            loc=loc,
            ip=ip,
        )

mlir/test/Dialect/XeGPU/transform-ops.mlir

Lines changed: 63 additions & 0 deletions
@@ -308,3 +308,66 @@ module attributes {transform.with_named_sequence} {
    transform.yield
  }
}

// -----

// CHECK-LABEL: @convert_layout_a
func.func @convert_layout_a(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
  %c0 = arith.constant 0 : index
  // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
  %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16]>>
  // CHECK: %[[V1:.+]] = xegpu.load_nd %[[V0]]
  %1 = xegpu.load_nd %0[%c0, %c0] : !xegpu.tensor_desc<256x32xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16]>> -> vector<256x32xf16>
  // CHECK: %[[V2:.+]] = xegpu.convert_layout %[[V1]]
  // CHECK: input_layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16]>
  // CHECK: target_layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [8, 16]>
  %2 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
  %3 = xegpu.load_nd %2[%c0, %c0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
  %4 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
  %5 = xegpu.load_nd %4[%c0, %c0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
  // CHECK: = xegpu.dpas %[[V2]]
  %6 = xegpu.dpas %1, %3, %5 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
  return
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value
    // CHECK: transform.xegpu.convert_layout %{{.*}}
    transform.xegpu.convert_layout %1 sg_layout = [8, 4] sg_data = [32, 32] inst_data = [8, 16] : !transform.any_value
    transform.yield
  }
}

// -----

// CHECK-LABEL: @convert_layout_a_sg_param
func.func @convert_layout_a_sg_param(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
  %c0 = arith.constant 0 : index
  // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
  %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16]>>
  // CHECK: %[[V1:.+]] = xegpu.load_nd %[[V0]]
  %1 = xegpu.load_nd %0[%c0, %c0] : !xegpu.tensor_desc<256x32xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16]>> -> vector<256x32xf16>
  // CHECK: %[[V2:.+]] = xegpu.convert_layout %[[V1]]
  // CHECK: input_layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16]>
  // CHECK: target_layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [8, 16]>
  %2 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
  %3 = xegpu.load_nd %2[%c0, %c0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
  %4 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
  %5 = xegpu.load_nd %4[%c0, %c0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
  // CHECK: = xegpu.dpas %[[V2]]
  %6 = xegpu.dpas %1, %3, %5 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
  return
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value
    // CHECK: transform.xegpu.convert_layout %{{.*}}
    %layout0 = transform.param.constant 8 : i64 -> !transform.param<i64>
    transform.xegpu.convert_layout %1 sg_layout = [%layout0, 4] sg_data = [32, 32] inst_data = [8, 16] : !transform.any_value, !transform.param<i64>
    transform.yield
  }
}

mlir/test/python/dialects/transform_xegpu_ext.py

Lines changed: 44 additions & 0 deletions
@@ -128,3 +128,47 @@ def setGPULaunchThreadsOp():
    # CHECK-LABEL: TEST: setGPULaunchThreadsOp
    # CHECK: transform.xegpu.set_gpu_launch_threads
    # CHECK: threads = [8, 4, 1]


@run
def ConvertLayoutMinimal():
    sequence = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate,
        [],
        transform.OperationType.get("xegpu.dpas"),
    )
    with InsertionPoint(sequence.body):
        operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [0])
        xegpu.ConvertLayoutOp(
            operand,
            sg_layout=[6, 4],
            sg_data=[32, 16],
        )
        transform.YieldOp()
    # CHECK-LABEL: TEST: ConvertLayoutMinimal
    # CHECK: transform.xegpu.convert_layout %
    # CHECK: sg_layout = [6, 4]
    # CHECK: sg_data = [32, 16]


@run
def ConvertLayout():
    sequence = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate,
        [],
        transform.OperationType.get("xegpu.dpas"),
    )
    with InsertionPoint(sequence.body):
        operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [1])
        xegpu.ConvertLayoutOp(
            operand,
            sg_layout=[6, 4],
            sg_data=[32, 16],
            inst_data=[8, 16],
        )
        transform.YieldOp()
    # CHECK-LABEL: TEST: ConvertLayout
    # CHECK: transform.xegpu.convert_layout %
    # CHECK: sg_layout = [6, 4]
    # CHECK: sg_data = [32, 16]
    # CHECK: inst_data = [8, 16]