Skip to content

Commit e9459dc

Browse files
committed
generalize mechanism to find producer layout
1 parent 4e8767c commit e9459dc

File tree

4 files changed

+98
-22
lines changed

4 files changed

+98
-22
lines changed

mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -253,10 +253,11 @@ def ConvertLayoutOp : Op<Transform_Dialect, "xegpu.convert_layout", [
253253
let summary = "Convert xegpu.layout attribute for a value.";
254254
let description = [{
255255
Adds an `xegpu.convert_layout` op to convert the `xegpu.layout` attribute
256-
of a value. First, the `xegpu.load_nd` producer op of the value is found.
257-
It must already be annotated with a layout. An `xegpu.convert_layout` op,
258-
whose destination layout is defined by the `sg_layout`, `sg_data` and
259-
optional `inst_data` attributes, is inserted after the load op.
256+
of a value. The source layout is inferred by inspecting the producer ops. A
257+
failure is emitted if the source layout cannot be found. An
258+
`xegpu.convert_layout` op, whose destination layout is defined by the
259+
`sg_layout`, `sg_data` and optional `inst_data` attributes, is emitted
260+
before the first use of the value.
260261
}];
261262

262263
let arguments = (ins TransformValueHandleTypeInterface:$target,

mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp

Lines changed: 43 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,33 @@ static std::optional<T> findProducerOfType(Value val) {
120120
return findProducerOfType<T>(producerOp->getOperand(0));
121121
}
122122

123+
/// Find layout attribute in producer chain.
124+
/// Traces producer ops until a layout attribute is found. Only traces through
125+
/// ops with a single operand, in other cases the op's result layout attribute
126+
/// must be set. Returns a null attribute if no layout attribute is found.
127+
xegpu::LayoutAttr findProducerLayout(Value val) {
128+
// Get layout attr from value or producer's attribute or operand.
129+
if (auto layoutAttr = dyn_cast_if_present<xegpu::LayoutAttr>(
130+
xegpu::getDistributeLayoutAttr(val)))
131+
return layoutAttr;
132+
133+
// Recurse up the producer chain.
134+
Operation *producerOp = val.getDefiningOp();
135+
if (!producerOp) {
136+
LDBG() << "Failed to find producer op.";
137+
return nullptr;
138+
}
139+
if (producerOp->getNumOperands() == 0) {
140+
LDBG() << "Producer has no operands.";
141+
return nullptr;
142+
}
143+
if (producerOp->getNumOperands() > 1) {
144+
LDBG() << "Producer has multiple operands.";
145+
return nullptr;
146+
}
147+
return findProducerLayout(producerOp->getOperand(0));
148+
}
149+
123150
/// Create a layout attribute from the given parameters.
124151
static xegpu::LayoutAttr
125152
createLayoutAttr(MLIRContext *ctx, ArrayRef<int32_t> sgLayout,
@@ -568,34 +595,33 @@ transform::ConvertLayoutOp::apply(transform::TransformRewriter &rewriter,
568595
<< llvm::range_size(targetValues) << ")";
569596
auto value = *targetValues.begin();
570597

571-
xegpu::LayoutAttr layoutAttr = nullptr;
598+
xegpu::LayoutAttr targetLayoutAttr = nullptr;
572599
auto status = getLayoutAttrFromOperands(getContext(), state, (*this),
573600
getMixedSgLayout(), getMixedSgData(),
574-
getMixedInstData(), layoutAttr);
601+
getMixedInstData(), targetLayoutAttr);
575602
if (!status.succeeded())
576603
return status;
577604

578-
// Get load op.
579-
auto maybeLoadOp = findProducerOfType<xegpu::LoadNdOp>(value);
580-
if (!maybeLoadOp)
581-
return emitSilenceableFailure(getLoc()) << "Could not find load op.";
582-
auto loadOp = *maybeLoadOp;
583-
// Get load op operand value layout
584-
auto producerLayoutAttr =
585-
xegpu::getDistributeLayoutAttr(loadOp.getOperand(0));
605+
// Find source layout attribute from the producer chain.
606+
auto producerLayoutAttr = findProducerLayout(value);
586607
if (!producerLayoutAttr)
587608
return emitSilenceableFailure(getLoc())
588-
<< "Operand producer op does not have a layout attr.";
609+
<< "Could not find a layout attribute in the producer chain.";
610+
611+
// Find first user op to define insertion point for layout conversion.
612+
if (value.use_empty())
613+
return emitSilenceableFailure(getLoc())
614+
<< "Value has no users to insert layout conversion.";
615+
Operation *userOp = *value.getUsers().begin();
589616

590-
if (producerLayoutAttr != layoutAttr) {
591-
rewriter.setInsertionPointAfter(loadOp.getOperation());
592-
auto source = loadOp.getResult();
617+
if (producerLayoutAttr != targetLayoutAttr) {
618+
rewriter.setInsertionPoint(userOp);
593619
auto convLayoutOp = xegpu::ConvertLayoutOp::create(
594-
rewriter, loadOp.getLoc(), source.getType(), source, producerLayoutAttr,
595-
layoutAttr);
620+
rewriter, value.getLoc(), value.getType(), value, producerLayoutAttr,
621+
targetLayoutAttr);
596622
// Replace all other uses of the value with the converted layout result.
597623
rewriter.replaceUsesWithIf(
598-
source, convLayoutOp.getResult(), [&](OpOperand &use) {
624+
value, convLayoutOp.getResult(), [&](OpOperand &use) {
599625
return use.getOwner() != convLayoutOp.getOperation();
600626
});
601627
}

mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,3 +155,24 @@ module attributes {transform.with_named_sequence} {
155155
transform.yield
156156
}
157157
}
158+
159+
// -----
160+
161+
// CHECK-LABEL: @convert_layout_no_producer_attr
162+
func.func @convert_layout_no_producer_attr(%arg0: vector<32x32xf16>, %arg1: vector<32x32xf16>) {
163+
%c0 = arith.constant 0 : index
164+
%0 = arith.addf %arg0, %arg1 : vector<32x32xf16>
165+
%1 = arith.extf %0 : vector<32x32xf16> to vector<32x32xf32>
166+
%2 = arith.truncf %1 : vector<32x32xf32> to vector<32x32xf16>
167+
return
168+
}
169+
170+
module attributes {transform.with_named_sequence} {
171+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
172+
%0 = transform.structured.match ops{["arith.truncf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
173+
%1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value
174+
// expected-error@below {{Could not find a layout attribute in the producer chain.}}
175+
transform.xegpu.convert_layout %1 sg_layout = [8, 4] sg_data = [32, 32] inst_data = [8, 16] : !transform.any_value
176+
transform.yield
177+
}
178+
}

mlir/test/Dialect/XeGPU/transform-ops.mlir

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -457,9 +457,37 @@ module attributes {transform.with_named_sequence} {
457457
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
458458
%0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op
459459
%1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value
460-
// CHECK: transform.xegpu.convert_layout %{{.*}}
461460
%layout0 = transform.param.constant 8 : i64 -> !transform.param<i64>
461+
// CHECK: transform.xegpu.convert_layout %{{.*}}
462462
transform.xegpu.convert_layout %1 sg_layout = [%layout0, 4] sg_data = [32, 32] inst_data = [8, 16] : !transform.any_value, !transform.param<i64>
463463
transform.yield
464464
}
465465
}
466+
467+
// -----
468+
469+
// CHECK-LABEL: @convert_layout_producer_attr
470+
func.func @convert_layout_producer_attr(%arg0: vector<32x32xf16>, %arg1: vector<32x32xf16>) {
471+
%c0 = arith.constant 0 : index
472+
%0 = arith.addf %arg0, %arg1 {layout_result_0 =
473+
#xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16]>} :
474+
vector<32x32xf16>
475+
// CHECK: %[[V0:.+]] = arith.extf
476+
%1 = arith.extf %0 : vector<32x32xf16> to vector<32x32xf32>
477+
// CHECK: %[[V1:.+]] = xegpu.convert_layout %[[V0]]
478+
// CHECK: input_layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16]>
479+
// CHECK: target_layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [8, 16]>
480+
// CHECK: %[[V0:.+]] = arith.truncf %[[V1]]
481+
%2 = arith.truncf %1 : vector<32x32xf32> to vector<32x32xf16>
482+
return
483+
}
484+
485+
module attributes {transform.with_named_sequence} {
486+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
487+
%0 = transform.structured.match ops{["arith.truncf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
488+
%1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value
489+
// CHECK: transform.xegpu.convert_layout %{{.*}}
490+
transform.xegpu.convert_layout %1 sg_layout = [8, 4] sg_data = [32, 32] inst_data = [8, 16] : !transform.any_value
491+
transform.yield
492+
}
493+
}

0 commit comments

Comments
 (0)