Skip to content

Commit 21646cb

Browse files
committed
convert_layout transform op returns a handle to created convert op
1 parent 1a03d07 commit 21646cb

File tree

5 files changed

+23
-20
lines changed

5 files changed

+23
-20
lines changed

mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,8 @@ def ConvertLayoutOp : Op<Transform_Dialect, "xegpu.convert_layout", [
257257
failure is emitted if source layout cannot be found. An
258258
`xegpu.convert_layout` op, whose destination layout is defined by the
259259
`sg_layout`, `sg_data` and optional `inst_data` attributes, is emitted
260-
before the first use of the value.
260+
before the first use of the value. Returns a handle to the emitted
261+
`xegpu.convert_layout` op.
261262
}];
262263

263264
let arguments = (ins TransformValueHandleTypeInterface:$target,
@@ -269,7 +270,7 @@ def ConvertLayoutOp : Op<Transform_Dialect, "xegpu.convert_layout", [
269270
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_inst_data
270271
);
271272

272-
let results = (outs);
273+
let results = (outs TransformHandleTypeInterface:$newConvertOp);
273274
let builders = [
274275
OpBuilder<(ins "Value":$target,
275276
"ArrayRef<OpFoldResult>":$mixedSgLayout,
@@ -283,7 +284,7 @@ def ConvertLayoutOp : Op<Transform_Dialect, "xegpu.convert_layout", [
283284
`sg_layout` `=` custom<DynamicIndexList>($sg_layout, $static_sg_layout)
284285
`sg_data` `=` custom<DynamicIndexList>($sg_data, $static_sg_data)
285286
(`inst_data` `=` custom<DynamicIndexList>($inst_data, $static_inst_data)^)?
286-
attr-dict `:` qualified(type(operands))
287+
attr-dict `:` functional-type(operands, results)
287288
}];
288289

289290
let extraClassDeclaration = [{

mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -614,18 +614,18 @@ transform::ConvertLayoutOp::apply(transform::TransformRewriter &rewriter,
614614
<< "Value has no users to insert layout conversion.";
615615
Operation *userOp = *value.getUsers().begin();
616616

617-
if (producerLayoutAttr != targetLayoutAttr) {
618-
rewriter.setInsertionPoint(userOp);
619-
auto convLayoutOp = xegpu::ConvertLayoutOp::create(
620-
rewriter, value.getLoc(), value.getType(), value, producerLayoutAttr,
621-
targetLayoutAttr);
622-
// Replace load op result with the converted layout.
623-
rewriter.replaceUsesWithIf(
624-
value, convLayoutOp.getResult(), [&](OpOperand &use) {
625-
return use.getOwner() != convLayoutOp.getOperation();
626-
});
627-
}
628-
617+
// Emit convert_layout op.
618+
rewriter.setInsertionPoint(userOp);
619+
auto convLayoutOp = xegpu::ConvertLayoutOp::create(
620+
rewriter, value.getLoc(), value.getType(), value, producerLayoutAttr,
621+
targetLayoutAttr);
622+
// Replace load op result with the converted layout.
623+
rewriter.replaceUsesWithIf(
624+
value, convLayoutOp.getResult(), [&](OpOperand &use) {
625+
return use.getOwner() != convLayoutOp.getOperation();
626+
});
627+
628+
results.set(llvm::cast<OpResult>(getResult()), {convLayoutOp});
629629
return DiagnosedSilenceableFailure::success();
630630
}
631631

@@ -635,6 +635,7 @@ void transform::ConvertLayoutOp::getEffects(
635635
onlyReadsHandle(getSgLayoutMutable(), effects);
636636
onlyReadsHandle(getSgDataMutable(), effects);
637637
onlyReadsHandle(getInstDataMutable(), effects);
638+
producesHandle(getOperation()->getOpResults(), effects);
638639
modifiesPayload(effects);
639640
}
640641

mlir/python/mlir/dialects/transform/xegpu.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,7 @@ def __init__(
295295
_,
296296
) = _dispatch_dynamic_index_list(inst_data)
297297
super().__init__(
298+
transform.AnyOpType.get(),
298299
target,
299300
dynamic_sg_layout,
300301
dynamic_sg_data,
@@ -323,4 +324,4 @@ def convert_layout(
323324
inst_data=inst_data,
324325
loc=loc,
325326
ip=ip,
326-
)
327+
).result

mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ module attributes {transform.with_named_sequence} {
172172
%0 = transform.structured.match ops{["arith.truncf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
173173
%1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value
174174
// expected-error@below {{Could not find a layout attribute in the producer chain.}}
175-
transform.xegpu.convert_layout %1 sg_layout = [8, 4] sg_data = [32, 32] inst_data = [8, 16] : !transform.any_value
175+
transform.xegpu.convert_layout %1 sg_layout = [8, 4] sg_data = [32, 32] inst_data = [8, 16] : (!transform.any_value) -> !transform.any_op
176176
transform.yield
177177
}
178178
}

mlir/test/Dialect/XeGPU/transform-ops.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -427,7 +427,7 @@ module attributes {transform.with_named_sequence} {
427427
%0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op
428428
%1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value
429429
// CHECK: transform.xegpu.convert_layout %{{.*}}
430-
transform.xegpu.convert_layout %1 sg_layout = [8, 4] sg_data = [32, 32] inst_data = [8, 16] : !transform.any_value
430+
transform.xegpu.convert_layout %1 sg_layout = [8, 4] sg_data = [32, 32] inst_data = [8, 16] : (!transform.any_value) -> !transform.any_op
431431
transform.yield
432432
}
433433
}
@@ -459,7 +459,7 @@ module attributes {transform.with_named_sequence} {
459459
%1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value
460460
%layout0 = transform.param.constant 8 : i64 -> !transform.param<i64>
461461
// CHECK: transform.xegpu.convert_layout %{{.*}}
462-
transform.xegpu.convert_layout %1 sg_layout = [%layout0, 4] sg_data = [32, 32] inst_data = [8, 16] : !transform.any_value, !transform.param<i64>
462+
transform.xegpu.convert_layout %1 sg_layout = [%layout0, 4] sg_data = [32, 32] inst_data = [8, 16] : (!transform.any_value, !transform.param<i64>) -> !transform.any_op
463463
transform.yield
464464
}
465465
}
@@ -487,7 +487,7 @@ module attributes {transform.with_named_sequence} {
487487
%0 = transform.structured.match ops{["arith.truncf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
488488
%1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value
489489
// CHECK: transform.xegpu.convert_layout %{{.*}}
490-
transform.xegpu.convert_layout %1 sg_layout = [8, 4] sg_data = [32, 32] inst_data = [8, 16] : !transform.any_value
490+
transform.xegpu.convert_layout %1 sg_layout = [8, 4] sg_data = [32, 32] inst_data = [8, 16] : (!transform.any_value) -> !transform.any_op
491491
transform.yield
492492
}
493493
}

0 commit comments

Comments
 (0)