Skip to content

Commit 3a68751

Browse files
authored
[MLIR][XeGPU][Transform] add xegpu.set_desc_layout transform op (#165615)
Adds the first XeGPU transform op, `xegpu.set_desc_layout`, which attachs a `xegpu.layout` attribute to the descriptor that a `xegpu.create_nd_tdesc` op returns.
1 parent 9d1b578 commit 3a68751

File tree

16 files changed

+584
-7
lines changed

16 files changed

+584
-7
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,6 @@ include "mlir/Interfaces/InferTypeOpInterface.td"
2121
include "mlir/IR/OpBase.td"
2222
include "mlir/IR/RegionKindInterface.td"
2323

24-
// This is roughly similar to OpFoldResult assuming the handle produces a single
25-
// value in the payload IR.
26-
def TransformAnyParamTypeOrAnyHandle : Type<
27-
Or<[TransformHandleTypeInterface.predicate,
28-
TransformParamTypeInterface.predicate]>,
29-
"transform any param type or any handle type">;
30-
3124
//===----------------------------------------------------------------------===//
3225
// Apply...PatternsOp
3326
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Transform/IR/TransformTypes.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,4 +103,9 @@ def TransformAnyHandle : Type<
103103
TransformValueHandleTypeInterface.predicate]>,
104104
"transform operation or value handle">;
105105

106+
def TransformAnyParamTypeOrAnyHandle : Type<
107+
Or<[TransformHandleTypeInterface.predicate,
108+
TransformParamTypeInterface.predicate]>,
109+
"transform any param type or any handle type">;
110+
106111
#endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMTYPES
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
add_subdirectory(IR)
22
add_subdirectory(Transforms)
3+
add_subdirectory(TransformOps)
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
set(LLVM_TARGET_DEFINITIONS XeGPUTransformOps.td)
2+
mlir_tablegen(XeGPUTransformOps.h.inc -gen-op-decls)
3+
mlir_tablegen(XeGPUTransformOps.cpp.inc -gen-op-defs)
4+
add_public_tablegen_target(MLIRXeGPUTransformOpsIncGen)
5+
6+
add_mlir_doc(XeGPUTransformOps XeGPUTransformOps Dialects/ -gen-op-doc)
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
//===- XeGPUTransformOps.h - XeGPU transformation ops -----------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_XEGPU_TRANSFORMOPS_XEGPUTRANSFORMOPS_H
10+
#define MLIR_DIALECT_XEGPU_TRANSFORMOPS_XEGPUTRANSFORMOPS_H
11+
12+
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
13+
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
14+
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
15+
#include "mlir/Dialect/Utils/StaticValueUtils.h"
16+
17+
#define GET_OP_CLASSES
18+
#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h.inc"
19+
20+
namespace mlir {
21+
class DialectRegistry;
22+
23+
namespace xegpu {
24+
void registerTransformDialectExtension(DialectRegistry &registry);
25+
} // namespace xegpu
26+
} // namespace mlir
27+
28+
#endif // MLIR_DIALECT_XEGPU_TRANSFORMOPS_XEGPUTRANSFORMOPS_H
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
//===- XeGPUTransformOps.td - XeGPU transformation ops -----*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef XEGPU_TRANSFORM_OPS
10+
#define XEGPU_TRANSFORM_OPS
11+
12+
include "mlir/Dialect/Transform/IR/TransformAttrs.td"
13+
include "mlir/Dialect/Transform/IR/TransformDialect.td"
14+
include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
15+
include "mlir/Dialect/Transform/IR/TransformTypes.td"
16+
include "mlir/Interfaces/SideEffectInterfaces.td"
17+
include "mlir/IR/OpBase.td"
18+
19+
def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [
20+
AttrSizedOperandSegments,
21+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
22+
TransformOpInterface
23+
]> {
24+
25+
let summary = "Set xegpu.layout attribute to a xegpu.create_nd_desc op result.";
26+
let description = [{
27+
Given an `xegpu.create_nd_desc` operation, this transform adds `xegpu.layout`
28+
attribute to the result tensor descriptor. The layout is defined by the
29+
`sg_layout`, and `sg_data` and optional `inst_data` attributes. Returns a handle
30+
to the transformed op.
31+
}];
32+
33+
let arguments = (ins
34+
TransformHandleTypeInterface : $target,
35+
Variadic<TransformAnyParamTypeOrAnyHandle> : $sg_layout,
36+
Variadic<TransformAnyParamTypeOrAnyHandle> : $sg_data,
37+
Variadic<TransformAnyParamTypeOrAnyHandle> : $inst_data,
38+
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_layout,
39+
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_data,
40+
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_inst_data
41+
);
42+
43+
let results = (outs TransformHandleTypeInterface : $transformed);
44+
let builders = [
45+
OpBuilder<(ins "Value":$target,
46+
"ArrayRef<OpFoldResult>":$mixedSgLayout,
47+
"ArrayRef<OpFoldResult>":$mixedSgData,
48+
"ArrayRef<OpFoldResult>":$mixedInstData
49+
)>,
50+
];
51+
52+
let assemblyFormat = [{
53+
$target
54+
`sg_layout` `=` custom<DynamicIndexList>($sg_layout, $static_sg_layout)
55+
`sg_data` `=` custom<DynamicIndexList>($sg_data, $static_sg_data)
56+
(`inst_data` `=` custom<DynamicIndexList>($inst_data, $static_inst_data)^)?
57+
attr-dict `:` functional-type(operands, results)
58+
}];
59+
60+
let extraClassDeclaration = [{
61+
::mlir::DiagnosedSilenceableFailure apply(
62+
::mlir::transform::TransformRewriter &rewriter,
63+
::mlir::transform::TransformResults &transformResults,
64+
::mlir::transform::TransformState &state);
65+
66+
::llvm::SmallVector<::mlir::OpFoldResult> getMixedSgLayout() {
67+
Builder b(getContext());
68+
return getMixedValues(getStaticSgLayout(), getSgLayout(), b);
69+
}
70+
::llvm::SmallVector<::mlir::OpFoldResult> getMixedSgData() {
71+
Builder b(getContext());
72+
return getMixedValues(getStaticSgData(), getSgData(), b);
73+
}
74+
::llvm::SmallVector<::mlir::OpFoldResult> getMixedInstData() {
75+
Builder b(getContext());
76+
return getMixedValues(getStaticInstData(), getInstData(), b);
77+
}
78+
}];
79+
}
80+
81+
#endif // XEGPU_TRANSFORM_OPS
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
add_subdirectory(IR)
22
add_subdirectory(Transforms)
33
add_subdirectory(Utils)
4+
add_subdirectory(TransformOps)
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
add_mlir_dialect_library(MLIRXeGPUTransformOps
2+
XeGPUTransformOps.cpp
3+
4+
ADDITIONAL_HEADER_DIRS
5+
${PROJECT_SOURCE_DIR}/mlir/Dialect/XeGPU/TransformOps/
6+
7+
DEPENDS
8+
MLIRXeGPUTransformOpsIncGen
9+
10+
LINK_LIBS PUBLIC
11+
MLIRXeGPUDialect
12+
MLIRXeGPUTransforms
13+
MLIRIR
14+
MLIRTransformDialect
15+
MLIRFuncDialect
16+
MLIRSCFDialect
17+
)
Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,225 @@
1+
//===- XeGPUTransformOps.cpp - Implementation of XeGPU transformation ops -===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h"
10+
#include "mlir/Dialect/SCF/IR/SCF.h"
11+
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
12+
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
13+
14+
#include <optional>
15+
16+
using namespace mlir;
17+
using namespace mlir::transform;
18+
19+
/// Assuming that `ofr` is an index attr or a param of index type
20+
/// or a transform dialect handle mapped to exactly one op
21+
/// with one index result, get that value and cast it to int type.
22+
static DiagnosedSilenceableFailure convertMixedValuesToInt(
23+
transform::TransformState &state, TransformOpInterface transformOp,
24+
SmallVectorImpl<int32_t> &result, ArrayRef<OpFoldResult> ofrs) {
25+
for (OpFoldResult ofr : ofrs) {
26+
// Attribute case.
27+
if (auto attr = dyn_cast<Attribute>(ofr)) {
28+
if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
29+
result.push_back(intAttr.getInt());
30+
continue;
31+
}
32+
return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
33+
}
34+
35+
// Transform param case.
36+
Value transformValue = cast<Value>(ofr);
37+
if (isa<TransformParamTypeInterface>(transformValue.getType())) {
38+
ArrayRef<Attribute> params = state.getParams(transformValue);
39+
if (params.size() != 1)
40+
return transformOp.emitDefiniteFailure()
41+
<< "requires exactly one parameter associated";
42+
result.push_back(
43+
cast<IntegerAttr>(params.front()).getValue().getSExtValue());
44+
continue;
45+
}
46+
47+
// Payload value case.
48+
auto payloadOps = state.getPayloadOps(transformValue);
49+
if (!llvm::hasSingleElement(payloadOps)) {
50+
DiagnosedSilenceableFailure diag =
51+
transformOp.emitSilenceableError()
52+
<< "handle must be mapped to exactly one payload op";
53+
diag.attachNote(transformValue.getLoc())
54+
<< "mapped to " << llvm::range_size(payloadOps) << " payload ops";
55+
return diag;
56+
}
57+
58+
Operation *op = *payloadOps.begin();
59+
if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
60+
DiagnosedSilenceableFailure diag =
61+
transformOp.emitSilenceableError()
62+
<< "payload op must have exactly 1 index result";
63+
diag.attachNote(op->getLoc())
64+
<< "has " << op->getNumResults() << " results";
65+
return diag;
66+
}
67+
68+
IntegerAttr intAttr;
69+
if (!matchPattern(op->getResult(0), m_Constant(&intAttr)))
70+
return transformOp.emitSilenceableError()
71+
<< "requires param or handle to be the result of a constant like "
72+
"op";
73+
74+
result.push_back(intAttr.getInt());
75+
}
76+
return DiagnosedSilenceableFailure::success();
77+
}
78+
79+
/// Create a layout attribute from the given parameters.
80+
static xegpu::LayoutAttr
81+
createLayoutAttr(MLIRContext *ctx, ArrayRef<int32_t> sgLayout,
82+
ArrayRef<int32_t> sgData,
83+
std::optional<ArrayRef<int32_t>> instData) {
84+
return xegpu::LayoutAttr::get(
85+
ctx, DenseI32ArrayAttr::get(ctx, sgLayout),
86+
DenseI32ArrayAttr::get(ctx, sgData),
87+
instData ? DenseI32ArrayAttr::get(ctx, instData.value()) : nullptr,
88+
/*lane_layout=*/nullptr,
89+
/*lane_data=*/nullptr,
90+
/*order=*/nullptr);
91+
}
92+
93+
/// Replace xegpu.create_nd_desc op with a new one with the given layout.
94+
static xegpu::CreateNdDescOp
95+
setDescLayout(transform::TransformRewriter &rewriter,
96+
xegpu::CreateNdDescOp descOp, xegpu::LayoutAttr layout) {
97+
assert(descOp.getMixedOffsets().size() == 0 &&
98+
"create desc op with offsets is not supported");
99+
auto oldTensorDesc = descOp.getType();
100+
auto descType = xegpu::TensorDescType::get(
101+
oldTensorDesc.getShape(), oldTensorDesc.getElementType(),
102+
/*array_length=*/oldTensorDesc.getArrayLength(),
103+
/*boundary_check=*/oldTensorDesc.getBoundaryCheck(),
104+
/*memory_space=*/oldTensorDesc.getMemorySpace(),
105+
/*layout=*/layout);
106+
107+
rewriter.setInsertionPointAfter(descOp);
108+
auto newDescOp = rewriter.replaceOpWithNewOp<xegpu::CreateNdDescOp>(
109+
descOp, descType, descOp.getSource(), descOp.getMixedSizes(),
110+
descOp.getMixedStrides());
111+
return newDescOp;
112+
}
113+
114+
void transform::SetDescLayoutOp::build(OpBuilder &builder,
115+
OperationState &result, Value target,
116+
ArrayRef<OpFoldResult> mixedSgLayout,
117+
ArrayRef<OpFoldResult> mixedSgData,
118+
ArrayRef<OpFoldResult> mixedInstData) {
119+
SmallVector<int64_t> staticSgLayout, staticSgData, staticInstData;
120+
SmallVector<Value> dynamicSgLayout, dynamicSgData, dynamicInstData;
121+
dispatchIndexOpFoldResults(mixedSgLayout, dynamicSgLayout, staticSgLayout);
122+
dispatchIndexOpFoldResults(mixedSgData, dynamicSgData, staticSgData);
123+
dispatchIndexOpFoldResults(mixedInstData, dynamicInstData, staticInstData);
124+
build(builder, result, target.getType(),
125+
/*target=*/target,
126+
/*sg_layout=*/dynamicSgLayout,
127+
/*sg_data=*/dynamicSgData,
128+
/*inst_data=*/dynamicInstData,
129+
/*static_sg_layout=*/staticSgLayout,
130+
/*static_sg_data=*/staticSgData,
131+
/*static_inst_data=*/staticInstData);
132+
}
133+
134+
DiagnosedSilenceableFailure
135+
transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter,
136+
transform::TransformResults &results,
137+
transform::TransformState &state) {
138+
auto targetOps = state.getPayloadOps(getTarget());
139+
if (!llvm::hasSingleElement(targetOps)) {
140+
return emitDefiniteFailure() << "requires exactly one targetOp handle (got "
141+
<< llvm::range_size(targetOps) << ")";
142+
}
143+
Operation *target = *targetOps.begin();
144+
145+
SmallVector<int32_t> sgLayout;
146+
DiagnosedSilenceableFailure status =
147+
convertMixedValuesToInt(state, (*this), sgLayout, getMixedSgLayout());
148+
if (!status.succeeded())
149+
return status;
150+
151+
SmallVector<int32_t> sgData;
152+
status = convertMixedValuesToInt(state, (*this), sgData, getMixedSgData());
153+
if (!status.succeeded())
154+
return status;
155+
156+
SmallVector<int32_t> instData;
157+
status =
158+
convertMixedValuesToInt(state, (*this), instData, getMixedInstData());
159+
if (!status.succeeded())
160+
return status;
161+
auto maybeInstData = instData.empty()
162+
? std::nullopt
163+
: std::optional<ArrayRef<int32_t>>(instData);
164+
165+
// For now only create_nd_desc op is supported.
166+
auto descOp = dyn_cast<xegpu::CreateNdDescOp>(target);
167+
if (!descOp) {
168+
auto diag = emitSilenceableFailure(getLoc())
169+
<< "Expected a xegpu.create_nd_desc op, but got: "
170+
<< target->getName();
171+
diag.attachNote(target->getLoc()) << "target op";
172+
return diag;
173+
}
174+
175+
// Set layout attr in desc op's return type. Replaces old desc op.
176+
auto layoutAttr =
177+
createLayoutAttr(rewriter.getContext(), sgLayout, sgData, maybeInstData);
178+
auto newdescOp = setDescLayout(rewriter, descOp, layoutAttr);
179+
180+
// Map result handles.
181+
results.set(cast<OpResult>(getTransformed()), {newdescOp.getOperation()});
182+
183+
return DiagnosedSilenceableFailure::success();
184+
}
185+
186+
void transform::SetDescLayoutOp::getEffects(
187+
::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
188+
consumesHandle(getTargetMutable(), effects);
189+
onlyReadsHandle(getSgLayoutMutable(), effects);
190+
onlyReadsHandle(getSgDataMutable(), effects);
191+
onlyReadsHandle(getInstDataMutable(), effects);
192+
producesHandle(getOperation()->getOpResults(), effects);
193+
modifiesPayload(effects);
194+
}
195+
196+
namespace {
197+
class XeGPUTransformDialectExtension
198+
: public transform::TransformDialectExtension<
199+
XeGPUTransformDialectExtension> {
200+
public:
201+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(XeGPUTransformDialectExtension)
202+
203+
using Base::Base;
204+
205+
void init();
206+
};
207+
208+
void XeGPUTransformDialectExtension::init() {
209+
declareGeneratedDialect<scf::SCFDialect>();
210+
declareGeneratedDialect<arith::ArithDialect>();
211+
declareGeneratedDialect<xegpu::XeGPUDialect>();
212+
213+
registerTransformOps<
214+
#define GET_OP_LIST
215+
#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc"
216+
>();
217+
}
218+
} // namespace
219+
220+
#define GET_OP_CLASSES
221+
#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc"
222+
223+
void mlir::xegpu::registerTransformDialectExtension(DialectRegistry &registry) {
224+
registry.addExtensions<XeGPUTransformDialectExtension>();
225+
}

0 commit comments

Comments
 (0)