Skip to content

Commit 8897891

Browse files
committed
[mlir][xegpu][transformops] add get_desc_op
1 parent bda7289 commit 8897891

File tree

5 files changed

+145
-1
lines changed

5 files changed

+145
-1
lines changed

mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,23 @@ include "mlir/Dialect/Transform/IR/TransformTypes.td"
1616
include "mlir/Interfaces/SideEffectInterfaces.td"
1717
include "mlir/IR/OpBase.td"
1818

19+
def GetDescOp : Op<Transform_Dialect, "xegpu.get_desc_op", [
20+
DeclareOpInterfaceMethods<TransformOpInterface>,
21+
NavigationTransformOpTrait, MemoryEffectsOpInterface
22+
]> {
23+
24+
let summary = "Get a handle to the descriptor op of a value.";
25+
let description = [{
26+
Traces the producers of the given value until an `xegpu.create_nd_tdesc`
27+
descriptor op is found. Returns a handle to it.
28+
}];
29+
30+
let arguments = (ins TransformValueHandleTypeInterface : $target);
31+
32+
let results = (outs TransformHandleTypeInterface : $descHandle);
33+
let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
34+
}
35+
1936
def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [
2037
AttrSizedOperandSegments,
2138
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,

mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313

1414
#include <optional>
1515

16+
#include "llvm/Support/Debug.h"
17+
#define DEBUG_TYPE "xegpu-transforms"
18+
1619
using namespace mlir;
1720
using namespace mlir::transform;
1821

@@ -76,6 +79,47 @@ static DiagnosedSilenceableFailure convertMixedValuesToInt(
7679
return DiagnosedSilenceableFailure::success();
7780
}
7881

82+
/// Find producer operation of type T for the given value.
83+
/// It's assumed that producer ops are chained through their first operand.
84+
/// Producer chain is traced trough loop block arguments (init values).
85+
template <typename T>
86+
static std::optional<T> findProducerOfType(Value val) {
87+
Value currentValue = val;
88+
if (!currentValue.getDefiningOp()) {
89+
// Value may be a block argument initialized outside a loop.
90+
if (val.getNumUses() == 0) {
91+
LLVM_DEBUG(llvm::dbgs()
92+
<< "Failed to find producer op, value has no uses.");
93+
return std::nullopt;
94+
}
95+
auto userOp = val.getUsers().begin();
96+
auto parentLoop = userOp->getParentOfType<LoopLikeOpInterface>();
97+
if (!parentLoop) {
98+
LLVM_DEBUG(llvm::dbgs() << "Failed to find producer op, not in a loop.");
99+
return std::nullopt;
100+
}
101+
int64_t iterArgIdx;
102+
if (auto iterArg = llvm::dyn_cast<BlockArgument>(currentValue)) {
103+
auto numInductionVars = parentLoop.getLoopInductionVars()->size();
104+
iterArgIdx = iterArg.getArgNumber() - numInductionVars;
105+
currentValue = parentLoop.getInits()[iterArgIdx];
106+
} else {
107+
LLVM_DEBUG(llvm::dbgs()
108+
<< "Failed to find producer op, value not in init values.");
109+
return std::nullopt;
110+
}
111+
}
112+
Operation *producerOp = currentValue.getDefiningOp();
113+
114+
if (auto matchingOp = dyn_cast<T>(producerOp))
115+
return matchingOp;
116+
117+
if (producerOp->getNumOperands() == 0)
118+
return std::nullopt;
119+
120+
return findProducerOfType<T>(producerOp->getOperand(0));
121+
}
122+
79123
/// Create a layout attribute from the given parameters.
80124
static xegpu::LayoutAttr
81125
createLayoutAttr(MLIRContext *ctx, ArrayRef<int32_t> sgLayout,
@@ -111,6 +155,28 @@ setDescLayout(transform::TransformRewriter &rewriter,
111155
return newDescOp;
112156
}
113157

158+
DiagnosedSilenceableFailure
159+
transform::GetDescOp::apply(transform::TransformRewriter &rewriter,
160+
transform::TransformResults &results,
161+
transform::TransformState &state) {
162+
163+
auto targetValues = state.getPayloadValues(getTarget());
164+
if (!llvm::hasSingleElement(targetValues)) {
165+
return emitDefiniteFailure()
166+
<< "requires exactly one target value handle (got "
167+
<< llvm::range_size(targetValues) << ")";
168+
}
169+
170+
auto maybeDescOp =
171+
findProducerOfType<xegpu::CreateNdDescOp>(*targetValues.begin());
172+
if (!maybeDescOp) {
173+
return emitSilenceableFailure(getLoc()) << "Could not find descriptor op.";
174+
}
175+
176+
results.set(llvm::cast<OpResult>(getResult()), {*maybeDescOp});
177+
return DiagnosedSilenceableFailure::success();
178+
}
179+
114180
void transform::SetDescLayoutOp::build(OpBuilder &builder,
115181
OperationState &result, Value target,
116182
ArrayRef<OpFoldResult> mixedSgLayout,

mlir/python/mlir/dialects/transform/xegpu.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
try:
99
from ...ir import *
10+
from ...dialects import transform
1011
from .._ods_common import _cext as _ods_cext
1112
from .._ods_common import (
1213
MixedValues,
@@ -20,6 +21,26 @@
2021
from typing import Union, Optional
2122

2223

24+
@_ods_cext.register_operation(_Dialect, replace=True)
25+
class GetDescOp(GetDescOp):
26+
"""Specialization for GetDescOp class."""
27+
28+
def __init__(
29+
self,
30+
target: Value,
31+
*,
32+
loc=None,
33+
ip=None,
34+
):
35+
desc_type = transform.AnyOpType.get()
36+
super().__init__(
37+
desc_type,
38+
target,
39+
loc=loc,
40+
ip=ip,
41+
)
42+
43+
2344
@_ods_cext.register_operation(_Dialect, replace=True)
2445
class SetDescLayoutOp(SetDescLayoutOp):
2546
"""Specialization for SetDescLayoutOp class."""

mlir/test/Dialect/XeGPU/transform-ops.mlir

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,30 @@
11
// RUN: mlir-opt %s -transform-interpreter -split-input-file -verify-diagnostics | FileCheck %s
22

3+
// CHECK-LABEL: @get_desc_op
4+
func.func @get_desc_op(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
5+
%c0 = arith.constant 0 : index
6+
%0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
7+
%1 = xegpu.load_nd %0[%c0, %c0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
8+
// expected-remark @below {{found desc op}}
9+
%2 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
10+
%3 = xegpu.load_nd %2[%c0, %c0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
11+
%4 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
12+
%5 = xegpu.load_nd %4[%c0, %c0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
13+
%6 = xegpu.dpas %1, %3, %5 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
14+
return
15+
}
16+
17+
module attributes {transform.with_named_sequence} {
18+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
19+
%0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op
20+
%1 = transform.get_operand %0[1] : (!transform.any_op) -> !transform.any_value
21+
%2 = transform.xegpu.get_desc_op %1 : (!transform.any_value) -> !transform.any_op
22+
transform.debug.emit_remark_at %2, "found desc op" : !transform.any_op
23+
transform.yield
24+
}
25+
}
26+
27+
// -----
328
// CHECK-LABEL: @set_desc_layout
429
func.func @set_desc_layout(%arg0: memref<4096x4096xf16>) {
530
// CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0

mlir/test/python/dialects/transform_xegpu_ext.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from mlir.ir import *
44
from mlir.dialects import transform
55
from mlir.dialects.transform import xegpu
6-
from mlir.dialects.transform import structured
6+
from mlir.dialects.transform import AnyValueType
77

88

99
def run(f):
@@ -16,6 +16,21 @@ def run(f):
1616
return f
1717

1818

19+
@run
20+
def getDescOpDefaultIndex():
21+
sequence = transform.SequenceOp(
22+
transform.FailurePropagationMode.Propagate,
23+
[],
24+
transform.OperationType.get("xegpu.dpas"),
25+
)
26+
with InsertionPoint(sequence.body):
27+
operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [0])
28+
desc_handle = xegpu.GetDescOp(operand)
29+
transform.YieldOp()
30+
# CHECK-LABEL: TEST: getDescOpDefaultIndex
31+
# CHECK: transform.xegpu.get_desc_op %
32+
33+
1934
@run
2035
def setDescLayoutMinimal():
2136
sequence = transform.SequenceOp(

0 commit comments

Comments
 (0)