Skip to content

Commit 9684e82

Browse files
committed
convert_layout takes input/target layouts; remove verification
1 parent 21646cb commit 9684e82

File tree

6 files changed

+190
-168
lines changed

6 files changed

+190
-168
lines changed

mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td

Lines changed: 46 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -253,37 +253,46 @@ def ConvertLayoutOp : Op<Transform_Dialect, "xegpu.convert_layout", [
253253
let summary = "Convert xegpu.layout attribute for a value.";
254254
let description = [{
255255
Adds an `xegpu.convert_layout` op to convert the `xegpu.layout` attribute
256-
of a value. The source layout is inferred by inspecting the producer ops. A
257-
failure is emitted if source layout cannot be found. An
258-
`xegpu.convert_layout` op, whose destination layout is defined by the
259-
`sg_layout`, `sg_data` and optional `inst_data` attributes, is emitted
260-
before the first use of the value. Returns a handle to the emitted
261-
`xegpu.convert_layout` op.
256+
of a value. The input and target layouts are defined by the `*sg_layout`,
257+
`*sg_data` and optional `*inst_data` attributes. Returns a handle to the
258+
emitted `xegpu.convert_layout` op.
262259
}];
263260

264261
let arguments = (ins TransformValueHandleTypeInterface:$target,
265-
Variadic<TransformAnyParamTypeOrAnyHandle>:$sg_layout,
266-
Variadic<TransformAnyParamTypeOrAnyHandle>:$sg_data,
267-
Variadic<TransformAnyParamTypeOrAnyHandle>:$inst_data,
268-
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_layout,
269-
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_data,
270-
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_inst_data
262+
Variadic<TransformAnyParamTypeOrAnyHandle>:$input_sg_layout,
263+
Variadic<TransformAnyParamTypeOrAnyHandle>:$input_sg_data,
264+
Variadic<TransformAnyParamTypeOrAnyHandle>:$input_inst_data,
265+
Variadic<TransformAnyParamTypeOrAnyHandle>:$target_sg_layout,
266+
Variadic<TransformAnyParamTypeOrAnyHandle>:$target_sg_data,
267+
Variadic<TransformAnyParamTypeOrAnyHandle>:$target_inst_data,
268+
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_input_sg_layout,
269+
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_input_sg_data,
270+
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_input_inst_data,
271+
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_target_sg_layout,
272+
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_target_sg_data,
273+
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_target_inst_data
271274
);
272275

273276
let results = (outs TransformHandleTypeInterface:$newConvertOp);
274277
let builders = [
275278
OpBuilder<(ins "Value":$target,
276-
"ArrayRef<OpFoldResult>":$mixedSgLayout,
277-
"ArrayRef<OpFoldResult>":$mixedSgData,
278-
"ArrayRef<OpFoldResult>":$mixedInstData
279+
"ArrayRef<OpFoldResult>":$mixedInputSgLayout,
280+
"ArrayRef<OpFoldResult>":$mixedInputSgData,
281+
"ArrayRef<OpFoldResult>":$mixedInputInstData,
282+
"ArrayRef<OpFoldResult>":$mixedTargetSgLayout,
283+
"ArrayRef<OpFoldResult>":$mixedTargetSgData,
284+
"ArrayRef<OpFoldResult>":$mixedTargetInstData
279285
)>,
280286
];
281287

282288
let assemblyFormat = [{
283289
$target
284-
`sg_layout` `=` custom<DynamicIndexList>($sg_layout, $static_sg_layout)
285-
`sg_data` `=` custom<DynamicIndexList>($sg_data, $static_sg_data)
286-
(`inst_data` `=` custom<DynamicIndexList>($inst_data, $static_inst_data)^)?
290+
`input_sg_layout` `=` custom<DynamicIndexList>($input_sg_layout, $static_input_sg_layout)
291+
`input_sg_data` `=` custom<DynamicIndexList>($input_sg_data, $static_input_sg_data)
292+
(`input_inst_data` `=` custom<DynamicIndexList>($input_inst_data, $static_input_inst_data)^)?
293+
`target_sg_layout` `=` custom<DynamicIndexList>($target_sg_layout, $static_target_sg_layout)
294+
`target_sg_data` `=` custom<DynamicIndexList>($target_sg_data, $static_target_sg_data)
295+
(`target_inst_data` `=` custom<DynamicIndexList>($target_inst_data, $static_target_inst_data)^)?
287296
attr-dict `:` functional-type(operands, results)
288297
}];
289298

@@ -293,17 +302,30 @@ def ConvertLayoutOp : Op<Transform_Dialect, "xegpu.convert_layout", [
293302
::mlir::transform::TransformResults &transformResults,
294303
::mlir::transform::TransformState &state);
295304

296-
::llvm::SmallVector<::mlir::OpFoldResult> getMixedSgLayout() {
305+
::llvm::SmallVector<::mlir::OpFoldResult> getMixedInputSgLayout() {
297306
Builder b(getContext());
298-
return getMixedValues(getStaticSgLayout(), getSgLayout(), b);
307+
return getMixedValues(getStaticInputSgLayout(), getInputSgLayout(), b);
299308
}
300-
::llvm::SmallVector<::mlir::OpFoldResult> getMixedSgData() {
309+
::llvm::SmallVector<::mlir::OpFoldResult> getMixedInputSgData() {
301310
Builder b(getContext());
302-
return getMixedValues(getStaticSgData(), getSgData(), b);
311+
return getMixedValues(getStaticInputSgData(), getInputSgData(), b);
303312
}
304-
::llvm::SmallVector<::mlir::OpFoldResult> getMixedInstData() {
313+
::llvm::SmallVector<::mlir::OpFoldResult> getMixedInputInstData() {
305314
Builder b(getContext());
306-
return getMixedValues(getStaticInstData(), getInstData(), b);
315+
return getMixedValues(getStaticInputInstData(), getInputInstData(), b);
316+
}
317+
318+
::llvm::SmallVector<::mlir::OpFoldResult> getMixedTargetSgLayout() {
319+
Builder b(getContext());
320+
return getMixedValues(getStaticTargetSgLayout(), getTargetSgLayout(), b);
321+
}
322+
::llvm::SmallVector<::mlir::OpFoldResult> getMixedTargetSgData() {
323+
Builder b(getContext());
324+
return getMixedValues(getStaticTargetSgData(), getTargetSgData(), b);
325+
}
326+
::llvm::SmallVector<::mlir::OpFoldResult> getMixedTargetInstData() {
327+
Builder b(getContext());
328+
return getMixedValues(getStaticTargetInstData(), getTargetInstData(), b);
307329
}
308330
}];
309331
}

mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp

Lines changed: 60 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -120,33 +120,6 @@ static std::optional<T> findProducerOfType(Value val) {
120120
return findProducerOfType<T>(producerOp->getOperand(0));
121121
}
122122

123-
/// Find layout attribute in producer chain.
124-
/// Traces producer ops until a layout attribute is found. Only traces through
125-
/// ops with a single operand, in other cases the op's result layout attribute
126-
/// must be set. Returns std::nullopt if no layout attribute is found.
127-
xegpu::LayoutAttr findProducerLayout(Value val) {
128-
// Get layout attr from value or producer's attribute or operand.
129-
if (auto layoutAttr = dyn_cast_if_present<xegpu::LayoutAttr>(
130-
xegpu::getDistributeLayoutAttr(val)))
131-
return layoutAttr;
132-
133-
// Recurse up the producer chain.
134-
Operation *producerOp = val.getDefiningOp();
135-
if (!producerOp) {
136-
LDBG() << "Failed to find producer op.";
137-
return nullptr;
138-
}
139-
if (producerOp->getNumOperands() == 0) {
140-
LDBG() << "Producer has no operands.";
141-
return nullptr;
142-
}
143-
if (producerOp->getNumOperands() > 1) {
144-
LDBG() << "Producer has multiple operands.";
145-
return nullptr;
146-
}
147-
return findProducerLayout(producerOp->getOperand(0));
148-
}
149-
150123
/// Create a layout attribute from the given parameters.
151124
static xegpu::LayoutAttr
152125
createLayoutAttr(MLIRContext *ctx, ArrayRef<int32_t> sgLayout,
@@ -564,24 +537,48 @@ void transform::InsertPrefetchOp::getEffects(
564537
modifiesPayload(effects);
565538
}
566539

567-
void transform::ConvertLayoutOp::build(OpBuilder &builder,
568-
OperationState &ostate, Value target,
569-
ArrayRef<OpFoldResult> mixedSgLayout,
570-
ArrayRef<OpFoldResult> mixedSgData,
571-
ArrayRef<OpFoldResult> mixedInstData) {
572-
SmallVector<int64_t> staticSgLayout, staticSgData, staticInstData;
573-
SmallVector<Value> dynamicSgLayout, dynamicSgData, dynamicInstData;
574-
dispatchIndexOpFoldResults(mixedSgLayout, dynamicSgLayout, staticSgLayout);
575-
dispatchIndexOpFoldResults(mixedSgData, dynamicSgData, staticSgData);
576-
dispatchIndexOpFoldResults(mixedInstData, dynamicInstData, staticInstData);
540+
void transform::ConvertLayoutOp::build(
541+
OpBuilder &builder, OperationState &ostate, Value target,
542+
ArrayRef<OpFoldResult> mixedInputSgLayout,
543+
ArrayRef<OpFoldResult> mixedInputSgData,
544+
ArrayRef<OpFoldResult> mixedInputInstData,
545+
ArrayRef<OpFoldResult> mixedTargetSgLayout,
546+
ArrayRef<OpFoldResult> mixedTargetSgData,
547+
ArrayRef<OpFoldResult> mixedTargetInstData) {
548+
SmallVector<int64_t> staticInputSgLayout, staticInputSgData,
549+
staticInputInstData;
550+
SmallVector<Value> dynamicInputSgLayout, dynamicInputSgData,
551+
dynamicInputInstData;
552+
dispatchIndexOpFoldResults(mixedInputSgLayout, dynamicInputSgLayout,
553+
staticInputSgLayout);
554+
dispatchIndexOpFoldResults(mixedInputSgData, dynamicInputSgData,
555+
staticInputSgData);
556+
dispatchIndexOpFoldResults(mixedInputInstData, dynamicInputInstData,
557+
staticInputInstData);
558+
SmallVector<int64_t> staticTargetSgLayout, staticTargetSgData,
559+
staticTargetInstData;
560+
SmallVector<Value> dynamicTargetSgLayout, dynamicTargetSgData,
561+
dynamicTargetInstData;
562+
dispatchIndexOpFoldResults(mixedTargetSgLayout, dynamicTargetSgLayout,
563+
staticTargetSgLayout);
564+
dispatchIndexOpFoldResults(mixedTargetSgData, dynamicTargetSgData,
565+
staticTargetSgData);
566+
dispatchIndexOpFoldResults(mixedTargetInstData, dynamicTargetInstData,
567+
staticTargetInstData);
577568
build(builder, ostate, target.getType(),
578569
/*target=*/target,
579-
/*sg_layout=*/dynamicSgLayout,
580-
/*sg_data=*/dynamicSgData,
581-
/*inst_data=*/dynamicInstData,
582-
/*static_sg_layout=*/staticSgLayout,
583-
/*static_sg_data=*/staticSgData,
584-
/*static_inst_data=*/staticInstData);
570+
/*input_sg_layout=*/dynamicInputSgLayout,
571+
/*input_sg_data=*/dynamicInputSgData,
572+
/*input_inst_data=*/dynamicInputInstData,
573+
/*target_sg_layout=*/dynamicTargetSgLayout,
574+
/*target_sg_data=*/dynamicTargetSgData,
575+
/*target_inst_data=*/dynamicTargetInstData,
576+
/*static_input_sg_layout=*/staticInputSgLayout,
577+
/*static_input_sg_data=*/staticInputSgData,
578+
/*static_input_inst_data=*/staticInputInstData,
579+
/*static_target_sg_layout=*/staticTargetSgLayout,
580+
/*static_target_sg_data=*/staticTargetSgData,
581+
/*static_target_inst_data=*/staticTargetInstData);
585582
}
586583

587584
DiagnosedSilenceableFailure
@@ -595,18 +592,20 @@ transform::ConvertLayoutOp::apply(transform::TransformRewriter &rewriter,
595592
<< llvm::range_size(targetValues) << ")";
596593
auto value = *targetValues.begin();
597594

598-
xegpu::LayoutAttr targetLayoutAttr = nullptr;
599-
auto status = getLayoutAttrFromOperands(getContext(), state, (*this),
600-
getMixedSgLayout(), getMixedSgData(),
601-
getMixedInstData(), targetLayoutAttr);
595+
// Construct layout attributes.
596+
xegpu::LayoutAttr inputLayoutAttr = nullptr;
597+
auto status = getLayoutAttrFromOperands(
598+
getContext(), state, (*this), getMixedInputSgLayout(),
599+
getMixedInputSgData(), getMixedInputInstData(), inputLayoutAttr);
602600
if (!status.succeeded())
603601
return status;
604602

605-
// Find source layout attribute from the producer chain.
606-
auto producerLayoutAttr = findProducerLayout(value);
607-
if (!producerLayoutAttr)
608-
return emitSilenceableFailure(getLoc())
609-
<< "Could not find a layout attribute in the producer chain.";
603+
xegpu::LayoutAttr targetLayoutAttr = nullptr;
604+
status = getLayoutAttrFromOperands(
605+
getContext(), state, (*this), getMixedTargetSgLayout(),
606+
getMixedTargetSgData(), getMixedTargetInstData(), targetLayoutAttr);
607+
if (!status.succeeded())
608+
return status;
610609

611610
// Find first user op to define insertion point for layout conversion.
612611
if (value.use_empty())
@@ -616,9 +615,9 @@ transform::ConvertLayoutOp::apply(transform::TransformRewriter &rewriter,
616615

617616
// Emit convert_layout op.
618617
rewriter.setInsertionPoint(userOp);
619-
auto convLayoutOp = xegpu::ConvertLayoutOp::create(
620-
rewriter, value.getLoc(), value.getType(), value, producerLayoutAttr,
621-
targetLayoutAttr);
618+
auto convLayoutOp =
619+
xegpu::ConvertLayoutOp::create(rewriter, value.getLoc(), value.getType(),
620+
value, inputLayoutAttr, targetLayoutAttr);
622621
// Replace load op result with the converted layout.
623622
rewriter.replaceUsesWithIf(
624623
value, convLayoutOp.getResult(), [&](OpOperand &use) {
@@ -632,9 +631,12 @@ transform::ConvertLayoutOp::apply(transform::TransformRewriter &rewriter,
632631
void transform::ConvertLayoutOp::getEffects(
633632
::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
634633
onlyReadsHandle(getTargetMutable(), effects);
635-
onlyReadsHandle(getSgLayoutMutable(), effects);
636-
onlyReadsHandle(getSgDataMutable(), effects);
637-
onlyReadsHandle(getInstDataMutable(), effects);
634+
onlyReadsHandle(getInputSgLayoutMutable(), effects);
635+
onlyReadsHandle(getInputSgDataMutable(), effects);
636+
onlyReadsHandle(getInputInstDataMutable(), effects);
637+
onlyReadsHandle(getTargetSgLayoutMutable(), effects);
638+
onlyReadsHandle(getTargetSgDataMutable(), effects);
639+
onlyReadsHandle(getTargetInstDataMutable(), effects);
638640
producesHandle(getOperation()->getOpResults(), effects);
639641
modifiesPayload(effects);
640642
}

mlir/python/mlir/dialects/transform/xegpu.py

Lines changed: 56 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -271,57 +271,88 @@ class ConvertLayoutOp(ConvertLayoutOp):
271271
def __init__(
272272
self,
273273
target: Value,
274-
sg_layout: MixedValues,
275-
sg_data: MixedValues,
274+
input_sg_layout: MixedValues,
275+
input_sg_data: MixedValues,
276+
target_sg_layout: MixedValues,
277+
target_sg_data: MixedValues,
276278
*,
277-
inst_data: Optional[MixedValues] = None,
279+
input_inst_data: Optional[MixedValues] = None,
280+
target_inst_data: Optional[MixedValues] = None,
278281
loc=None,
279282
ip=None,
280283
):
281-
inst_data = [] if inst_data is None else inst_data
284+
input_inst_data = [] if input_inst_data is None else input_inst_data
285+
target_inst_data = [] if target_inst_data is None else target_inst_data
282286
(
283-
dynamic_sg_layout,
284-
static_sg_layout,
287+
dynamic_input_sg_layout,
288+
static_input_sg_layout,
285289
_,
286-
) = _dispatch_dynamic_index_list(sg_layout)
290+
) = _dispatch_dynamic_index_list(input_sg_layout)
287291
(
288-
dynamic_sg_data,
289-
static_sg_data,
292+
dynamic_input_sg_data,
293+
static_input_sg_data,
290294
_,
291-
) = _dispatch_dynamic_index_list(sg_data)
295+
) = _dispatch_dynamic_index_list(input_sg_data)
292296
(
293-
dynamic_inst_data,
294-
static_inst_data,
297+
dynamic_input_inst_data,
298+
static_input_inst_data,
295299
_,
296-
) = _dispatch_dynamic_index_list(inst_data)
300+
) = _dispatch_dynamic_index_list(input_inst_data)
301+
(
302+
dynamic_target_sg_layout,
303+
static_target_sg_layout,
304+
_,
305+
) = _dispatch_dynamic_index_list(target_sg_layout)
306+
(
307+
dynamic_target_sg_data,
308+
static_target_sg_data,
309+
_,
310+
) = _dispatch_dynamic_index_list(target_sg_data)
311+
(
312+
dynamic_target_inst_data,
313+
static_target_inst_data,
314+
_,
315+
) = _dispatch_dynamic_index_list(target_inst_data)
297316
super().__init__(
298317
transform.AnyOpType.get(),
299318
target,
300-
dynamic_sg_layout,
301-
dynamic_sg_data,
302-
dynamic_inst_data,
303-
static_sg_layout=static_sg_layout,
304-
static_sg_data=static_sg_data,
305-
static_inst_data=static_inst_data,
319+
dynamic_input_sg_layout,
320+
dynamic_input_sg_data,
321+
dynamic_input_inst_data,
322+
dynamic_target_sg_layout,
323+
dynamic_target_sg_data,
324+
dynamic_target_inst_data,
325+
static_input_sg_layout=static_input_sg_layout,
326+
static_input_sg_data=static_input_sg_data,
327+
static_input_inst_data=static_input_inst_data,
328+
static_target_sg_layout=static_target_sg_layout,
329+
static_target_sg_data=static_target_sg_data,
330+
static_target_inst_data=static_target_inst_data,
306331
loc=loc,
307332
ip=ip,
308333
)
309334

310335

311336
def convert_layout(
312337
target: Value,
313-
sg_layout: MixedValues,
314-
sg_data: MixedValues,
338+
input_sg_layout: MixedValues,
339+
input_sg_data: MixedValues,
340+
target_sg_layout: MixedValues,
341+
target_sg_data: MixedValues,
315342
*,
316-
inst_data: Optional[MixedValues] = None,
343+
input_inst_data: Optional[MixedValues] = None,
344+
target_inst_data: Optional[MixedValues] = None,
317345
loc=None,
318346
ip=None,
319347
) -> ConvertLayoutOp:
320348
return ConvertLayoutOp(
321349
target,
322-
sg_layout,
323-
sg_data,
324-
inst_data=inst_data,
350+
input_sg_layout,
351+
input_sg_data,
352+
target_sg_layout,
353+
target_sg_data,
354+
input_inst_data=input_inst_data,
355+
target_inst_data=target_inst_data,
325356
loc=loc,
326357
ip=ip,
327358
).result

mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -155,24 +155,3 @@ module attributes {transform.with_named_sequence} {
155155
transform.yield
156156
}
157157
}
158-
159-
// -----
160-
161-
// CHECK-LABEL: @convert_layout_no_producer_attr
162-
func.func @convert_layout_no_producer_attr(%arg0: vector<32x32xf16>, %arg1: vector<32x32xf16>) {
163-
%c0 = arith.constant 0 : index
164-
%0 = arith.addf %arg0, %arg1 : vector<32x32xf16>
165-
%1 = arith.extf %0 : vector<32x32xf16> to vector<32x32xf32>
166-
%2 = arith.truncf %1 : vector<32x32xf32> to vector<32x32xf16>
167-
return
168-
}
169-
170-
module attributes {transform.with_named_sequence} {
171-
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
172-
%0 = transform.structured.match ops{["arith.truncf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
173-
%1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value
174-
// expected-error@below {{Could not find a layout attribute in the producer chain.}}
175-
transform.xegpu.convert_layout %1 sg_layout = [8, 4] sg_data = [32, 32] inst_data = [8, 16] : (!transform.any_value) -> !transform.any_op
176-
transform.yield
177-
}
178-
}

0 commit comments

Comments
 (0)