Skip to content

Commit 1553f90

Browse files
authored
[MLIR][XeGPU][TransformOps] Add get_desc_op (llvm#166801)
Add `transform.xegpu.get_desc_op` transform op that finds a `xegpu.create_nd_tdesc` producer op of a `Value`.
1 parent a5d4ba7 commit 1553f90

File tree

5 files changed

+187
-6
lines changed

5 files changed

+187
-6
lines changed

mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,24 @@ include "mlir/Dialect/Transform/IR/TransformTypes.td"
1616
include "mlir/Interfaces/SideEffectInterfaces.td"
1717
include "mlir/IR/OpBase.td"
1818

19+
def GetDescOp : Op<Transform_Dialect, "xegpu.get_desc_op", [
  DeclareOpInterfaceMethods<TransformOpInterface>,
  NavigationTransformOpTrait, MemoryEffectsOpInterface
]> {

  let summary = "Get a handle to the descriptor op of a value.";
  let description = [{
    Traces the producers of the given value until an `xegpu.create_nd_tdesc`
    descriptor op is found. Returns a handle to it. Currently traces
    producers by following only the first operand of producer ops.
  }];

  // Takes a value handle and yields an op handle to the found descriptor op.
  let arguments = (ins TransformValueHandleTypeInterface:$target);
  let results = (outs TransformHandleTypeInterface:$descHandle);
  let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
}
36+
1937
def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [
2038
AttrSizedOperandSegments,
2139
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
@@ -31,16 +49,16 @@ def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [
3149
}];
3250

3351
let arguments = (ins
34-
TransformHandleTypeInterface : $target,
35-
Variadic<TransformAnyParamTypeOrAnyHandle> : $sg_layout,
36-
Variadic<TransformAnyParamTypeOrAnyHandle> : $sg_data,
37-
Variadic<TransformAnyParamTypeOrAnyHandle> : $inst_data,
52+
TransformHandleTypeInterface:$target,
53+
Variadic<TransformAnyParamTypeOrAnyHandle>:$sg_layout,
54+
Variadic<TransformAnyParamTypeOrAnyHandle>:$sg_data,
55+
Variadic<TransformAnyParamTypeOrAnyHandle>:$inst_data,
3856
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_layout,
3957
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_data,
4058
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_inst_data
4159
);
4260

43-
let results = (outs TransformHandleTypeInterface : $transformed);
61+
let results = (outs TransformHandleTypeInterface:$transformed);
4462
let builders = [
4563
OpBuilder<(ins "Value":$target,
4664
"ArrayRef<OpFoldResult>":$mixedSgLayout,

mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313

1414
#include <optional>
1515

16+
#include "llvm/Support/DebugLog.h"
17+
#define DEBUG_TYPE "xegpu-transforms"
18+
1619
using namespace mlir;
1720
using namespace mlir::transform;
1821

@@ -76,6 +79,45 @@ static DiagnosedSilenceableFailure convertMixedValuesToInt(
7679
return DiagnosedSilenceableFailure::success();
7780
}
7881

82+
/// Find producer operation of type T for the given value.
83+
/// It's assumed that producer ops are chained through their first operand.
84+
/// Producer chain is traced trough loop block arguments (init values).
85+
template <typename T>
86+
static std::optional<T> findProducerOfType(Value val) {
87+
Value currentValue = val;
88+
if (!currentValue.getDefiningOp()) {
89+
// Value may be a block argument initialized outside a loop.
90+
if (val.getNumUses() == 0) {
91+
LDBG() << "Failed to find producer op, value has no uses.";
92+
return std::nullopt;
93+
}
94+
auto userOp = val.getUsers().begin();
95+
auto parentLoop = userOp->getParentOfType<LoopLikeOpInterface>();
96+
if (!parentLoop) {
97+
LDBG() << "Failed to find producer op, not in a loop.";
98+
return std::nullopt;
99+
}
100+
int64_t iterArgIdx;
101+
if (auto iterArg = llvm::dyn_cast<BlockArgument>(currentValue)) {
102+
auto numInductionVars = parentLoop.getLoopInductionVars()->size();
103+
iterArgIdx = iterArg.getArgNumber() - numInductionVars;
104+
currentValue = parentLoop.getInits()[iterArgIdx];
105+
} else {
106+
LDBG() << "Failed to find producer op, value not in init values.";
107+
return std::nullopt;
108+
}
109+
}
110+
Operation *producerOp = currentValue.getDefiningOp();
111+
112+
if (auto matchingOp = dyn_cast<T>(producerOp))
113+
return matchingOp;
114+
115+
if (producerOp->getNumOperands() == 0)
116+
return std::nullopt;
117+
118+
return findProducerOfType<T>(producerOp->getOperand(0));
119+
}
120+
79121
/// Create a layout attribute from the given parameters.
80122
static xegpu::LayoutAttr
81123
createLayoutAttr(MLIRContext *ctx, ArrayRef<int32_t> sgLayout,
@@ -111,6 +153,29 @@ setDescLayout(transform::TransformRewriter &rewriter,
111153
return newDescOp;
112154
}
113155

156+
DiagnosedSilenceableFailure
transform::GetDescOp::apply(transform::TransformRewriter &rewriter,
                            transform::TransformResults &results,
                            transform::TransformState &state) {
  // This op expects its value handle to map to exactly one payload value.
  auto payloadValues = state.getPayloadValues(getTarget());
  if (!llvm::hasSingleElement(payloadValues))
    return emitDefiniteFailure()
           << "requires exactly one target value handle (got "
           << llvm::range_size(payloadValues) << ")";

  // Walk the producer chain of the payload value until the
  // xegpu.create_nd_tdesc feeding it is found.
  Value payloadValue = *payloadValues.begin();
  std::optional<xegpu::CreateNdDescOp> descOp =
      findProducerOfType<xegpu::CreateNdDescOp>(payloadValue);
  if (!descOp)
    return emitSilenceableFailure(getLoc())
           << "Could not find a matching descriptor op when walking the "
              "producer chain of the first operand.";

  results.set(llvm::cast<OpResult>(getResult()), {*descOp});
  return DiagnosedSilenceableFailure::success();
}
178+
114179
void transform::SetDescLayoutOp::build(OpBuilder &builder,
115180
OperationState &result, Value target,
116181
ArrayRef<OpFoldResult> mixedSgLayout,

mlir/python/mlir/dialects/transform/xegpu.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
try:
99
from ...ir import *
10+
from ...dialects import transform
1011
from .._ods_common import _cext as _ods_cext
1112
from .._ods_common import (
1213
MixedValues,
@@ -20,6 +21,26 @@
2021
from typing import Union, Optional
2122

2223

24+
@_ods_cext.register_operation(_Dialect, replace=True)
class GetDescOp(GetDescOp):
    """Specialization for GetDescOp class."""

    def __init__(self, target: Value, *, loc=None, ip=None):
        # The descriptor handle result is always emitted as !transform.any_op.
        result_type = transform.AnyOpType.get()
        super().__init__(result_type, target, loc=loc, ip=ip)
42+
43+
2344
@_ods_cext.register_operation(_Dialect, replace=True)
2445
class SetDescLayoutOp(SetDescLayoutOp):
2546
"""Specialization for SetDescLayoutOp class."""

mlir/test/Dialect/XeGPU/transform-ops.mlir

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,67 @@
11
// RUN: mlir-opt %s -transform-interpreter -split-input-file -verify-diagnostics | FileCheck %s
22

3+
// CHECK-LABEL: @get_desc_op_a
4+
func.func @get_desc_op_a(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
5+
%c32 = arith.constant 32 : index
6+
%c4096 = arith.constant 4096 : index
7+
%c0 = arith.constant 0 : index
8+
%0 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
9+
%1 = xegpu.load_nd %0[%c0, %c0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
10+
// expected-remark @below {{found desc op}}
11+
%3 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
12+
%4 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
13+
%2 = scf.for %arg3 = %c0 to %c4096 step %c32 iter_args(%arg4 = %1) -> (vector<256x256xf16>) {
14+
%5 = xegpu.load_nd %3[%c0, %arg3] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
15+
%6 = xegpu.load_nd %4[%arg3, %c0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
16+
%7 = xegpu.dpas %5, %6, %arg4 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
17+
scf.yield %7 : vector<256x256xf16>
18+
}
19+
return
20+
}
21+
22+
module attributes {transform.with_named_sequence} {
23+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
24+
%0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op
25+
%1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value
26+
%2 = transform.xegpu.get_desc_op %1 : (!transform.any_value) -> !transform.any_op
27+
transform.debug.emit_remark_at %2, "found desc op" : !transform.any_op
28+
transform.yield
29+
}
30+
}
31+
32+
// -----
33+
34+
// CHECK-LABEL: @get_desc_op_c
35+
func.func @get_desc_op_c(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
36+
%c32 = arith.constant 32 : index
37+
%c4096 = arith.constant 4096 : index
38+
%c0 = arith.constant 0 : index
39+
// expected-remark @below {{found desc op}}
40+
%0 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
41+
%1 = xegpu.load_nd %0[%c0, %c0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
42+
%3 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
43+
%4 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
44+
%2 = scf.for %arg3 = %c0 to %c4096 step %c32 iter_args(%arg4 = %1) -> (vector<256x256xf16>) {
45+
%5 = xegpu.load_nd %3[%c0, %arg3] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
46+
%6 = xegpu.load_nd %4[%arg3, %c0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
47+
%7 = xegpu.dpas %5, %6, %arg4 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
48+
scf.yield %7 : vector<256x256xf16>
49+
}
50+
return
51+
}
52+
53+
module attributes {transform.with_named_sequence} {
54+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
55+
%0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op
56+
%1 = transform.get_operand %0[2] : (!transform.any_op) -> !transform.any_value
57+
%2 = transform.xegpu.get_desc_op %1 : (!transform.any_value) -> !transform.any_op
58+
transform.debug.emit_remark_at %2, "found desc op" : !transform.any_op
59+
transform.yield
60+
}
61+
}
62+
63+
// -----
64+
365
// CHECK-LABEL: @set_desc_layout
466
func.func @set_desc_layout(%arg0: memref<4096x4096xf16>) {
567
// CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0

mlir/test/python/dialects/transform_xegpu_ext.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from mlir.ir import *
44
from mlir.dialects import transform
55
from mlir.dialects.transform import xegpu
6-
from mlir.dialects.transform import structured
6+
from mlir.dialects.transform import AnyValueType
77

88

99
def run(f):
@@ -16,6 +16,21 @@ def run(f):
1616
return f
1717

1818

19+
@run
def getDescOpDefaultIndex():
    sequence = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate,
        [],
        transform.OperationType.get("xegpu.dpas"),
    )
    with InsertionPoint(sequence.body):
        # First operand of the matched dpas op as a value handle.
        dpas_operand = transform.GetOperandOp(
            AnyValueType.get(), sequence.bodyTarget, [0]
        )
        xegpu.GetDescOp(dpas_operand)
        transform.YieldOp()
    # CHECK-LABEL: TEST: getDescOpDefaultIndex
    # CHECK: transform.xegpu.get_desc_op %
32+
33+
1934
@run
2035
def setDescLayoutMinimal():
2136
sequence = transform.SequenceOp(

0 commit comments

Comments
 (0)