Skip to content

Commit 5f6ff5d

Browse files
committed
convert_layout transform op returns a handle to created convert op
1 parent ee597cd commit 5f6ff5d

File tree

5 files changed

+23
-20
lines changed

5 files changed

+23
-20
lines changed

mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,8 @@ def ConvertLayoutOp : Op<Transform_Dialect, "xegpu.convert_layout", [
213213
failure is emitted if source layout cannot be found. An
214214
`xegpu.convert_layout` op, whose destination layout is defined by the
215215
`sg_layout`, `sg_data` and optional `inst_data` attributes, is emitted
216-
before the first use of the value.
216+
before the first use of the value. Returns a handle to the emitted
217+
`xegpu.convert_layout` op.
217218
}];
218219

219220
let arguments = (ins TransformValueHandleTypeInterface:$target,
@@ -225,7 +226,7 @@ def ConvertLayoutOp : Op<Transform_Dialect, "xegpu.convert_layout", [
225226
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_inst_data
226227
);
227228

228-
let results = (outs);
229+
let results = (outs TransformHandleTypeInterface:$newConvertOp);
229230
let builders = [
230231
OpBuilder<(ins "Value":$target,
231232
"ArrayRef<OpFoldResult>":$mixedSgLayout,
@@ -239,7 +240,7 @@ def ConvertLayoutOp : Op<Transform_Dialect, "xegpu.convert_layout", [
239240
`sg_layout` `=` custom<DynamicIndexList>($sg_layout, $static_sg_layout)
240241
`sg_data` `=` custom<DynamicIndexList>($sg_data, $static_sg_data)
241242
(`inst_data` `=` custom<DynamicIndexList>($inst_data, $static_inst_data)^)?
242-
attr-dict `:` qualified(type(operands))
243+
attr-dict `:` functional-type(operands, results)
243244
}];
244245

245246
let extraClassDeclaration = [{

mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -482,18 +482,18 @@ transform::ConvertLayoutOp::apply(transform::TransformRewriter &rewriter,
482482
<< "Value has no users to insert layout conversion.";
483483
Operation *userOp = *value.getUsers().begin();
484484

485-
if (producerLayoutAttr != targetLayoutAttr) {
486-
rewriter.setInsertionPoint(userOp);
487-
auto convLayoutOp = xegpu::ConvertLayoutOp::create(
488-
rewriter, value.getLoc(), value.getType(), value, producerLayoutAttr,
489-
targetLayoutAttr);
490-
// Replace load op result with the converted layout.
491-
rewriter.replaceUsesWithIf(
492-
value, convLayoutOp.getResult(), [&](OpOperand &use) {
493-
return use.getOwner() != convLayoutOp.getOperation();
494-
});
495-
}
496-
485+
// Emit convert_layout op.
486+
rewriter.setInsertionPoint(userOp);
487+
auto convLayoutOp = xegpu::ConvertLayoutOp::create(
488+
rewriter, value.getLoc(), value.getType(), value, producerLayoutAttr,
489+
targetLayoutAttr);
490+
// Replace load op result with the converted layout.
491+
rewriter.replaceUsesWithIf(
492+
value, convLayoutOp.getResult(), [&](OpOperand &use) {
493+
return use.getOwner() != convLayoutOp.getOperation();
494+
});
495+
496+
results.set(llvm::cast<OpResult>(getResult()), {convLayoutOp});
497497
return DiagnosedSilenceableFailure::success();
498498
}
499499

@@ -503,6 +503,7 @@ void transform::ConvertLayoutOp::getEffects(
503503
onlyReadsHandle(getSgLayoutMutable(), effects);
504504
onlyReadsHandle(getSgDataMutable(), effects);
505505
onlyReadsHandle(getInstDataMutable(), effects);
506+
producesHandle(getOperation()->getOpResults(), effects);
506507
modifiesPayload(effects);
507508
}
508509

mlir/python/mlir/dialects/transform/xegpu.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,7 @@ def __init__(
253253
_,
254254
) = _dispatch_dynamic_index_list(inst_data)
255255
super().__init__(
256+
transform.AnyOpType.get(),
256257
target,
257258
dynamic_sg_layout,
258259
dynamic_sg_data,
@@ -281,4 +282,4 @@ def convert_layout(
281282
inst_data=inst_data,
282283
loc=loc,
283284
ip=ip,
284-
)
285+
).result

mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ module attributes {transform.with_named_sequence} {
141141
%0 = transform.structured.match ops{["arith.truncf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
142142
%1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value
143143
// expected-error@below {{Could not find a layout attribute in the producer chain.}}
144-
transform.xegpu.convert_layout %1 sg_layout = [8, 4] sg_data = [32, 32] inst_data = [8, 16] : !transform.any_value
144+
transform.xegpu.convert_layout %1 sg_layout = [8, 4] sg_data = [32, 32] inst_data = [8, 16] : (!transform.any_value) -> !transform.any_op
145145
transform.yield
146146
}
147147
}

mlir/test/Dialect/XeGPU/transform-ops.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,7 @@ module attributes {transform.with_named_sequence} {
335335
%0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op
336336
%1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value
337337
// CHECK: transform.xegpu.convert_layout %{{.*}}
338-
transform.xegpu.convert_layout %1 sg_layout = [8, 4] sg_data = [32, 32] inst_data = [8, 16] : !transform.any_value
338+
transform.xegpu.convert_layout %1 sg_layout = [8, 4] sg_data = [32, 32] inst_data = [8, 16] : (!transform.any_value) -> !transform.any_op
339339
transform.yield
340340
}
341341
}
@@ -367,7 +367,7 @@ module attributes {transform.with_named_sequence} {
367367
%1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value
368368
%layout0 = transform.param.constant 8 : i64 -> !transform.param<i64>
369369
// CHECK: transform.xegpu.convert_layout %{{.*}}
370-
transform.xegpu.convert_layout %1 sg_layout = [%layout0, 4] sg_data = [32, 32] inst_data = [8, 16] : !transform.any_value, !transform.param<i64>
370+
transform.xegpu.convert_layout %1 sg_layout = [%layout0, 4] sg_data = [32, 32] inst_data = [8, 16] : (!transform.any_value, !transform.param<i64>) -> !transform.any_op
371371
transform.yield
372372
}
373373
}
@@ -395,7 +395,7 @@ module attributes {transform.with_named_sequence} {
395395
%0 = transform.structured.match ops{["arith.truncf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
396396
%1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value
397397
// CHECK: transform.xegpu.convert_layout %{{.*}}
398-
transform.xegpu.convert_layout %1 sg_layout = [8, 4] sg_data = [32, 32] inst_data = [8, 16] : !transform.any_value
398+
transform.xegpu.convert_layout %1 sg_layout = [8, 4] sg_data = [32, 32] inst_data = [8, 16] : (!transform.any_value) -> !transform.any_op
399399
transform.yield
400400
}
401401
}

0 commit comments

Comments
 (0)