Skip to content

Commit 280693d

Browse files
committed
generalize mechanism to find producer layout
1 parent 549c81b commit 280693d

File tree

4 files changed

+98
-22
lines changed

4 files changed

+98
-22
lines changed

mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -209,10 +209,11 @@ def ConvertLayoutOp : Op<Transform_Dialect, "xegpu.convert_layout", [
209209
let summary = "Convert xegpu.layout attribute for a value.";
210210
let description = [{
211211
Adds an `xegpu.convert_layout` op to convert the `xegpu.layout` attribute
212-
of a value. First, the `xegpu.load_nd` producer op of the value is found.
213-
It must already be annotated with a layout. An `xegpu.convert_layout` op,
214-
whose destination layout is defined by the `sg_layout`, `sg_data` and
215-
optional `inst_data` attributes, is inserted after the load op.
212+
of a value. The source layout is inferred by inspecting the producer ops. A
213+
failure is emitted if source layout cannot be found. An
214+
`xegpu.convert_layout` op, whose destination layout is defined by the
215+
`sg_layout`, `sg_data` and optional `inst_data` attributes, is emitted
216+
before the first use of the value.
216217
}];
217218

218219
let arguments = (ins TransformValueHandleTypeInterface:$target,

mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp

Lines changed: 43 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,33 @@ static std::optional<T> findProducerOfType(Value val) {
119119
return findProducerOfType<T>(producerOp->getOperand(0));
120120
}
121121

122+
/// Find layout attribute in producer chain.
123+
/// Traces producer ops until a layout attribute is found. Only traces through
124+
/// ops with a single operand, in other cases the op's result layout attribute
125+
/// must be set. Returns std::nullopt if no layout attribute is found.
126+
xegpu::LayoutAttr findProducerLayout(Value val) {
127+
// Get layout attr from value or producer's attribute or operand.
128+
if (auto layoutAttr = dyn_cast_if_present<xegpu::LayoutAttr>(
129+
xegpu::getDistributeLayoutAttr(val)))
130+
return layoutAttr;
131+
132+
// Recurse up the producer chain.
133+
Operation *producerOp = val.getDefiningOp();
134+
if (!producerOp) {
135+
LDBG() << "Failed to find producer op.";
136+
return nullptr;
137+
}
138+
if (producerOp->getNumOperands() == 0) {
139+
LDBG() << "Producer has no operands.";
140+
return nullptr;
141+
}
142+
if (producerOp->getNumOperands() > 1) {
143+
LDBG() << "Producer has multiple operands.";
144+
return nullptr;
145+
}
146+
return findProducerLayout(producerOp->getOperand(0));
147+
}
148+
122149
/// Create a layout attribute from the given parameters.
123150
static xegpu::LayoutAttr
124151
createLayoutAttr(MLIRContext *ctx, ArrayRef<int32_t> sgLayout,
@@ -436,34 +463,33 @@ transform::ConvertLayoutOp::apply(transform::TransformRewriter &rewriter,
436463
<< llvm::range_size(targetValues) << ")";
437464
auto value = *targetValues.begin();
438465

439-
xegpu::LayoutAttr layoutAttr = nullptr;
466+
xegpu::LayoutAttr targetLayoutAttr = nullptr;
440467
auto status = getLayoutAttrFromOperands(getContext(), state, (*this),
441468
getMixedSgLayout(), getMixedSgData(),
442-
getMixedInstData(), layoutAttr);
469+
getMixedInstData(), targetLayoutAttr);
443470
if (!status.succeeded())
444471
return status;
445472

446-
// Get load op.
447-
auto maybeLoadOp = findProducerOfType<xegpu::LoadNdOp>(value);
448-
if (!maybeLoadOp)
449-
return emitSilenceableFailure(getLoc()) << "Could not find load op.";
450-
auto loadOp = *maybeLoadOp;
451-
// Get load op operand value layout
452-
auto producerLayoutAttr =
453-
xegpu::getDistributeLayoutAttr(loadOp.getOperand(0));
473+
// Find source layout attribute from the producer chain.
474+
auto producerLayoutAttr = findProducerLayout(value);
454475
if (!producerLayoutAttr)
455476
return emitSilenceableFailure(getLoc())
456-
<< "Operand producer op does not have a layout attr.";
477+
<< "Could not find a layout attribute in the producer chain.";
478+
479+
// Find first user op to define insertion point for layout conversion.
480+
if (value.use_empty())
481+
return emitSilenceableFailure(getLoc())
482+
<< "Value has no users to insert layout conversion.";
483+
Operation *userOp = *value.getUsers().begin();
457484

458-
if (producerLayoutAttr != layoutAttr) {
459-
rewriter.setInsertionPointAfter(loadOp.getOperation());
460-
auto source = loadOp.getResult();
485+
if (producerLayoutAttr != targetLayoutAttr) {
486+
rewriter.setInsertionPoint(userOp);
461487
auto convLayoutOp = xegpu::ConvertLayoutOp::create(
462-
rewriter, loadOp.getLoc(), source.getType(), source, producerLayoutAttr,
463-
layoutAttr);
488+
rewriter, value.getLoc(), value.getType(), value, producerLayoutAttr,
489+
targetLayoutAttr);
464490
// Replace load op result with the converted layout.
465491
rewriter.replaceUsesWithIf(
466-
source, convLayoutOp.getResult(), [&](OpOperand &use) {
492+
value, convLayoutOp.getResult(), [&](OpOperand &use) {
467493
return use.getOwner() != convLayoutOp.getOperation();
468494
});
469495
}

mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,3 +124,24 @@ module attributes {transform.with_named_sequence} {
124124
transform.yield
125125
}
126126
}
127+
128+
// -----
129+
130+
// CHECK-LABEL: @convert_layout_no_producer_attr
131+
func.func @convert_layout_no_producer_attr(%arg0: vector<32x32xf16>, %arg1: vector<32x32xf16>) {
132+
%c0 = arith.constant 0 : index
133+
%0 = arith.addf %arg0, %arg1 : vector<32x32xf16>
134+
%1 = arith.extf %0 : vector<32x32xf16> to vector<32x32xf32>
135+
%2 = arith.truncf %1 : vector<32x32xf32> to vector<32x32xf16>
136+
return
137+
}
138+
139+
module attributes {transform.with_named_sequence} {
140+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
141+
%0 = transform.structured.match ops{["arith.truncf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
142+
%1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value
143+
// expected-error@below {{Could not find a layout attribute in the producer chain.}}
144+
transform.xegpu.convert_layout %1 sg_layout = [8, 4] sg_data = [32, 32] inst_data = [8, 16] : !transform.any_value
145+
transform.yield
146+
}
147+
}

mlir/test/Dialect/XeGPU/transform-ops.mlir

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -365,9 +365,37 @@ module attributes {transform.with_named_sequence} {
365365
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
366366
%0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op
367367
%1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value
368-
// CHECK: transform.xegpu.convert_layout %{{.*}}
369368
%layout0 = transform.param.constant 8 : i64 -> !transform.param<i64>
369+
// CHECK: transform.xegpu.convert_layout %{{.*}}
370370
transform.xegpu.convert_layout %1 sg_layout = [%layout0, 4] sg_data = [32, 32] inst_data = [8, 16] : !transform.any_value, !transform.param<i64>
371371
transform.yield
372372
}
373373
}
374+
375+
// -----
376+
377+
// CHECK-LABEL: @convert_layout_producer_attr
378+
func.func @convert_layout_producer_attr(%arg0: vector<32x32xf16>, %arg1: vector<32x32xf16>) {
379+
%c0 = arith.constant 0 : index
380+
%0 = arith.addf %arg0, %arg1 {layout_result_0 =
381+
#xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16]>} :
382+
vector<32x32xf16>
383+
// CHECK: %[[V0:.+]] = arith.extf
384+
%1 = arith.extf %0 : vector<32x32xf16> to vector<32x32xf32>
385+
// CHECK: %[[V1:.+]] = xegpu.convert_layout %[[V0]]
386+
// CHECK: input_layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16]>
387+
// CHECK: target_layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [8, 16]>
388+
// CHECK: %[[V0:.+]] = arith.truncf %[[V1]]
389+
%2 = arith.truncf %1 : vector<32x32xf32> to vector<32x32xf16>
390+
return
391+
}
392+
393+
module attributes {transform.with_named_sequence} {
394+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
395+
%0 = transform.structured.match ops{["arith.truncf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
396+
%1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value
397+
// CHECK: transform.xegpu.convert_layout %{{.*}}
398+
transform.xegpu.convert_layout %1 sg_layout = [8, 4] sg_data = [32, 32] inst_data = [8, 16] : !transform.any_value
399+
transform.yield
400+
}
401+
}

0 commit comments

Comments
 (0)