Skip to content

Commit 515c56b

Browse files
committed
[mlir][xegpu][transformops] add convert_layout op
1 parent 7838dbe commit 515c56b

File tree

5 files changed

+292
-1
lines changed

5 files changed

+292
-1
lines changed

mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,4 +244,66 @@ def InsertPrefetchOp : Op<Transform_Dialect, "xegpu.insert_prefetch", [
244244
}];
245245
}
246246

247+
def ConvertLayoutOp : Op<Transform_Dialect, "xegpu.convert_layout", [
248+
AttrSizedOperandSegments,
249+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
250+
TransformOpInterface
251+
]> {
252+
253+
let summary = "Convert xegpu.layout attribute for a value.";
254+
let description = [{
255+
Adds an `xegpu.convert_layout` op to convert the `xegpu.layout` attribute
256+
of a value. First, the `xegpu.load_nd` producer op of the value is found.
257+
It must already be annotated with a layout. An `xegpu.convert_layout` op,
258+
whose destination layout is defined by the `sg_layout`, `sg_data` and
259+
optional `inst_data` attributes, is inserted after the load op.
260+
}];
261+
262+
let arguments = (ins TransformValueHandleTypeInterface:$target,
263+
Variadic<TransformAnyParamTypeOrAnyHandle>:$sg_layout,
264+
Variadic<TransformAnyParamTypeOrAnyHandle>:$sg_data,
265+
Variadic<TransformAnyParamTypeOrAnyHandle>:$inst_data,
266+
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_layout,
267+
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_data,
268+
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_inst_data
269+
);
270+
271+
let results = (outs);
272+
let builders = [
273+
OpBuilder<(ins "Value":$target,
274+
"ArrayRef<OpFoldResult>":$mixedSgLayout,
275+
"ArrayRef<OpFoldResult>":$mixedSgData,
276+
"ArrayRef<OpFoldResult>":$mixedInstData
277+
)>,
278+
];
279+
280+
let assemblyFormat = [{
281+
$target
282+
`sg_layout` `=` custom<DynamicIndexList>($sg_layout, $static_sg_layout)
283+
`sg_data` `=` custom<DynamicIndexList>($sg_data, $static_sg_data)
284+
(`inst_data` `=` custom<DynamicIndexList>($inst_data, $static_inst_data)^)?
285+
attr-dict `:` qualified(type(operands))
286+
}];
287+
288+
let extraClassDeclaration = [{
289+
::mlir::DiagnosedSilenceableFailure apply(
290+
::mlir::transform::TransformRewriter &rewriter,
291+
::mlir::transform::TransformResults &transformResults,
292+
::mlir::transform::TransformState &state);
293+
294+
::llvm::SmallVector<::mlir::OpFoldResult> getMixedSgLayout() {
295+
Builder b(getContext());
296+
return getMixedValues(getStaticSgLayout(), getSgLayout(), b);
297+
}
298+
::llvm::SmallVector<::mlir::OpFoldResult> getMixedSgData() {
299+
Builder b(getContext());
300+
return getMixedValues(getStaticSgData(), getSgData(), b);
301+
}
302+
::llvm::SmallVector<::mlir::OpFoldResult> getMixedInstData() {
303+
Builder b(getContext());
304+
return getMixedValues(getStaticInstData(), getInstData(), b);
305+
}
306+
}];
307+
}
308+
247309
#endif // XEGPU_TRANSFORM_OPS

mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -537,6 +537,85 @@ void transform::InsertPrefetchOp::getEffects(
537537
modifiesPayload(effects);
538538
}
539539

540+
void transform::ConvertLayoutOp::build(OpBuilder &builder,
541+
OperationState &ostate, Value target,
542+
ArrayRef<OpFoldResult> mixedSgLayout,
543+
ArrayRef<OpFoldResult> mixedSgData,
544+
ArrayRef<OpFoldResult> mixedInstData) {
545+
SmallVector<int64_t> staticSgLayout, staticSgData, staticInstData;
546+
SmallVector<Value> dynamicSgLayout, dynamicSgData, dynamicInstData;
547+
dispatchIndexOpFoldResults(mixedSgLayout, dynamicSgLayout, staticSgLayout);
548+
dispatchIndexOpFoldResults(mixedSgData, dynamicSgData, staticSgData);
549+
dispatchIndexOpFoldResults(mixedInstData, dynamicInstData, staticInstData);
550+
build(builder, ostate, target.getType(),
551+
/*target=*/target,
552+
/*sg_layout=*/dynamicSgLayout,
553+
/*sg_data=*/dynamicSgData,
554+
/*inst_data=*/dynamicInstData,
555+
/*static_sg_layout=*/staticSgLayout,
556+
/*static_sg_data=*/staticSgData,
557+
/*static_inst_data=*/staticInstData);
558+
}
559+
560+
DiagnosedSilenceableFailure
561+
transform::ConvertLayoutOp::apply(transform::TransformRewriter &rewriter,
562+
transform::TransformResults &results,
563+
transform::TransformState &state) {
564+
auto targetValues = state.getPayloadValues(getTarget());
565+
if (!llvm::hasSingleElement(targetValues)) {
566+
return emitDefiniteFailure()
567+
<< "requires exactly one target value handle (got "
568+
<< llvm::range_size(targetValues) << ")";
569+
}
570+
571+
auto value = *targetValues.begin();
572+
573+
xegpu::LayoutAttr layoutAttr = nullptr;
574+
auto status = getLayoutAttrFromOperands(getContext(), state, (*this),
575+
getMixedSgLayout(), getMixedSgData(),
576+
getMixedInstData(), layoutAttr);
577+
if (!status.succeeded())
578+
return status;
579+
580+
// Get load op.
581+
auto maybeLoadOp = findProducerOfType<xegpu::LoadNdOp>(value);
582+
if (!maybeLoadOp) {
583+
return emitSilenceableFailure(getLoc()) << "Could not find load op.";
584+
}
585+
auto loadOp = *maybeLoadOp;
586+
// Get load op operand value layout
587+
auto producerLayoutAttr =
588+
xegpu::getDistributeLayoutAttr(loadOp.getOperand(0));
589+
if (!producerLayoutAttr) {
590+
return emitSilenceableFailure(getLoc())
591+
<< "Operand producer op does not have a layout attr.";
592+
}
593+
594+
if (producerLayoutAttr != layoutAttr) {
595+
rewriter.setInsertionPointAfter(loadOp.getOperation());
596+
auto source = loadOp.getResult();
597+
auto convLayoutOp = xegpu::ConvertLayoutOp::create(
598+
rewriter, loadOp.getLoc(), source.getType(), source, producerLayoutAttr,
599+
layoutAttr);
600+
// Replace load op result with the converted layout.
601+
rewriter.replaceUsesWithIf(
602+
source, convLayoutOp.getResult(), [&](OpOperand &use) {
603+
return use.getOwner() != convLayoutOp.getOperation();
604+
});
605+
}
606+
607+
return DiagnosedSilenceableFailure::success();
608+
}
609+
610+
void transform::ConvertLayoutOp::getEffects(
611+
::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
612+
onlyReadsHandle(getTargetMutable(), effects);
613+
onlyReadsHandle(getSgLayoutMutable(), effects);
614+
onlyReadsHandle(getSgDataMutable(), effects);
615+
onlyReadsHandle(getInstDataMutable(), effects);
616+
modifiesPayload(effects);
617+
}
618+
540619
namespace {
541620
class XeGPUTransformDialectExtension
542621
: public transform::TransformDialectExtension<

mlir/python/mlir/dialects/transform/xegpu.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,4 +210,47 @@ def insert_prefetch(
210210
loc=None,
211211
ip=None,
212212
) -> OpResult:
213-
return InsertPrefetchOp(target, nb_prefetch=nb_prefetch, loc=loc, ip=ip).result
213+
return InsertPrefetchOp(target, nb_prefetch=nb_prefetch, loc=loc, ip=ip).result
214+
215+
216+
@_ods_cext.register_operation(_Dialect, replace=True)
217+
class ConvertLayoutOp(ConvertLayoutOp):
218+
"""Specialization for ConvertLayoutOp class."""
219+
220+
def __init__(
221+
self,
222+
target: Value,
223+
sg_layout: MixedValues,
224+
sg_data: MixedValues,
225+
*,
226+
inst_data: Optional[MixedValues] = None,
227+
loc=None,
228+
ip=None,
229+
):
230+
inst_data = [] if inst_data is None else inst_data
231+
(
232+
dynamic_sg_layout,
233+
static_sg_layout,
234+
_,
235+
) = _dispatch_dynamic_index_list(sg_layout)
236+
(
237+
dynamic_sg_data,
238+
static_sg_data,
239+
_,
240+
) = _dispatch_dynamic_index_list(sg_data)
241+
(
242+
dynamic_inst_data,
243+
static_inst_data,
244+
_,
245+
) = _dispatch_dynamic_index_list(inst_data)
246+
super().__init__(
247+
target,
248+
dynamic_sg_layout,
249+
dynamic_sg_data,
250+
dynamic_inst_data,
251+
static_sg_layout=static_sg_layout,
252+
static_sg_data=static_sg_data,
253+
static_inst_data=static_inst_data,
254+
loc=loc,
255+
ip=ip,
256+
)

mlir/test/Dialect/XeGPU/transform-ops.mlir

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -400,3 +400,66 @@ module attributes {transform.with_named_sequence} {
400400
transform.yield
401401
}
402402
}
403+
404+
// -----
405+
406+
// CHECK-LABEL: @convert_layout_a
407+
func.func @convert_layout_a(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
408+
%c0 = arith.constant 0 : index
409+
// CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
410+
%0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16]>>
411+
// CHECK: %[[V1:.+]] = xegpu.load_nd %[[V0]]
412+
%1 = xegpu.load_nd %0[%c0, %c0] : !xegpu.tensor_desc<256x32xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16]>> -> vector<256x32xf16>
413+
// CHECK: %[[V2:.+]] = xegpu.convert_layout %[[V1]]
414+
// CHECK: input_layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16]>
415+
// CHECK: target_layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [8, 16]>
416+
%2 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
417+
%3 = xegpu.load_nd %2[%c0, %c0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
418+
%4 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
419+
%5 = xegpu.load_nd %4[%c0, %c0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
420+
// CHECK: = xegpu.dpas %[[V2]]
421+
%6 = xegpu.dpas %1, %3, %5 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
422+
return
423+
}
424+
425+
module attributes {transform.with_named_sequence} {
426+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
427+
%0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op
428+
%1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value
429+
// CHECK: transform.xegpu.convert_layout %{{.*}}
430+
transform.xegpu.convert_layout %1 sg_layout = [8, 4] sg_data = [32, 32] inst_data = [8, 16] : !transform.any_value
431+
transform.yield
432+
}
433+
}
434+
435+
// -----
436+
437+
// CHECK-LABEL: @convert_layout_a_sg_param
438+
func.func @convert_layout_a_sg_param(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
439+
%c0 = arith.constant 0 : index
440+
// CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
441+
%0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16]>>
442+
// CHECK: %[[V1:.+]] = xegpu.load_nd %[[V0]]
443+
%1 = xegpu.load_nd %0[%c0, %c0] : !xegpu.tensor_desc<256x32xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16]>> -> vector<256x32xf16>
444+
// CHECK: %[[V2:.+]] = xegpu.convert_layout %[[V1]]
445+
// CHECK: input_layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16]>
446+
// CHECK: target_layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [8, 16]>
447+
%2 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
448+
%3 = xegpu.load_nd %2[%c0, %c0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
449+
%4 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
450+
%5 = xegpu.load_nd %4[%c0, %c0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
451+
// CHECK: = xegpu.dpas %[[V2]]
452+
%6 = xegpu.dpas %1, %3, %5 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
453+
return
454+
}
455+
456+
module attributes {transform.with_named_sequence} {
457+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
458+
%0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op
459+
%1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value
460+
// CHECK: transform.xegpu.convert_layout %{{.*}}
461+
%layout0 = transform.param.constant 8 : i64 -> !transform.param<i64>
462+
transform.xegpu.convert_layout %1 sg_layout = [%layout0, 4] sg_data = [32, 32] inst_data = [8, 16] : !transform.any_value, !transform.param<i64>
463+
transform.yield
464+
}
465+
}

mlir/test/python/dialects/transform_xegpu_ext.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,3 +193,47 @@ def insertPrefetchNbPrefetchParam():
193193
# CHECK: %[[PARAM_OP:.*]] = transform.param.constant 2
194194
# CHECK: transform.xegpu.insert_prefetch %[[OPR]]
195195
# CHECK-SAME: nb_prefetch = %[[PARAM_OP]]
196+
197+
198+
@run
199+
def ConvertLayoutMinimal():
200+
sequence = transform.SequenceOp(
201+
transform.FailurePropagationMode.Propagate,
202+
[],
203+
transform.OperationType.get("xegpu.dpas"),
204+
)
205+
with InsertionPoint(sequence.body):
206+
operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [0])
207+
xegpu.ConvertLayoutOp(
208+
operand,
209+
sg_layout=[6, 4],
210+
sg_data=[32, 16],
211+
)
212+
transform.YieldOp()
213+
# CHECK-LABEL: TEST: ConvertLayoutMinimal
214+
# CHECK: transform.xegpu.convert_layout %
215+
# CHECK: sg_layout = [6, 4]
216+
# CHECK: sg_data = [32, 16]
217+
218+
219+
@run
220+
def ConvertLayout():
221+
sequence = transform.SequenceOp(
222+
transform.FailurePropagationMode.Propagate,
223+
[],
224+
transform.OperationType.get("xegpu.dpas"),
225+
)
226+
with InsertionPoint(sequence.body):
227+
operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [1])
228+
xegpu.ConvertLayoutOp(
229+
operand,
230+
sg_layout=[6, 4],
231+
sg_data=[32, 16],
232+
inst_data=[8, 16],
233+
)
234+
transform.YieldOp()
235+
# CHECK-LABEL: TEST: ConvertLayout
236+
# CHECK: transform.xegpu.convert_layout %
237+
# CHECK: sg_layout = [6, 4]
238+
# CHECK: sg_data = [32, 16]
239+
# CHECK: inst_data = [8, 16]

0 commit comments

Comments
 (0)