Skip to content

Commit 44a5e63

Browse files
committed
[mlir][xegpu] add xegpu.set_desc_layout transform op
1 parent 2dca188 commit 44a5e63

File tree

13 files changed

+581
-0
lines changed

13 files changed

+581
-0
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
add_subdirectory(IR)
22
add_subdirectory(Transforms)
3+
# XeGPU transform dialect extension ops.
add_subdirectory(TransformOps)
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# Generate the op declarations and definitions from XeGPUTransformOps.td and
# expose them through the MLIRXeGPUTransformOpsIncGen target.
set(LLVM_TARGET_DEFINITIONS XeGPUTransformOps.td)
2+
mlir_tablegen(XeGPUTransformOps.h.inc -gen-op-decls)
3+
mlir_tablegen(XeGPUTransformOps.cpp.inc -gen-op-defs)
4+
add_public_tablegen_target(MLIRXeGPUTransformOpsIncGen)
5+
6+
# Generate the op documentation from the same tablegen file.
add_mlir_doc(XeGPUTransformOps XeGPUTransformOps Dialects/ -gen-op-doc)
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
//===- XeGPUTransformOps.h - XeGPU transformation ops -----------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_XEGPU_TRANSFORMOPS_XEGPUTRANSFORMOPS_H
10+
#define MLIR_DIALECT_XEGPU_TRANSFORMOPS_XEGPUTRANSFORMOPS_H
11+
12+
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
13+
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
14+
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
15+
#include "mlir/Dialect/Utils/StaticValueUtils.h"
16+
17+
#define GET_OP_CLASSES
18+
#include <mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h.inc>
19+
20+
namespace mlir {
class DialectRegistry;
} // namespace mlir

namespace mlir::xegpu {
/// Adds the XeGPU transform dialect extension, which provides the XeGPU
/// transform ops, to `registry`.
void registerTransformDialectExtension(DialectRegistry &registry);
} // namespace mlir::xegpu
27+
28+
#endif // MLIR_DIALECT_XEGPU_TRANSFORMOPS_XEGPUTRANSFORMOPS_H
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
//===- XeGPUTransformOps.td - XeGPU transformation ops -----*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef XEGPU_EXTENSION
10+
#define XEGPU_EXTENSION
11+
12+
include "mlir/Dialect/Transform/IR/TransformAttrs.td"
13+
include "mlir/Dialect/Transform/IR/TransformDialect.td"
14+
include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
15+
include "mlir/Dialect/Transform/IR/TransformTypes.td"
16+
include "mlir/Interfaces/SideEffectInterfaces.td"
17+
include "mlir/IR/OpBase.td"
18+
19+
// Type constraint satisfied by any transform dialect handle type or any
// transform dialect parameter type.
def TransformAnyParamTypeOrAnyHandle
    : Type<Or<[TransformHandleTypeInterface.predicate,
               TransformParamTypeInterface.predicate]>,
           "transform any param type or any handle type">;
23+
24+
// Transform op that attaches an `xegpu.layout` to the tensor descriptor
// produced by a descriptor-creation op. Each of `sg_layout`, `sg_data` and
// `inst_data` is a mixed static/dynamic index list: static entries are stored
// in the matching `static_*` attribute, dynamic entries are passed as
// transform params or handles.
def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [
  AttrSizedOperandSegments,
  DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
  TransformOpInterface
]> {

  let summary = "Set xegpu.layout attribute to an xegpu op result.";
  let description = [{
    Given an `xegpu.create_nd_tdesc` operation, this transform adds `xegpu.layout`
    attribute to the result tensor descriptor. The layout is defined by the
    `sg_layout`, `sg_data` and `inst_data` attributes. Each entry of these
    lists may be given as a static integer, as a transform parameter, or as a
    handle to a single payload op that produces a constant index value.

    The `target` handle is consumed. Returns a handle to the transformed op.
  }];

  let arguments = (ins
                   TransformHandleTypeInterface : $target,
                   Variadic<TransformAnyParamTypeOrAnyHandle> : $sg_layout,
                   Variadic<TransformAnyParamTypeOrAnyHandle> : $sg_data,
                   Variadic<TransformAnyParamTypeOrAnyHandle> : $inst_data,
                   DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_layout,
                   DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_data,
                   DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_inst_data
                   );

  let results = (outs TransformHandleTypeInterface : $transformed);
  // Convenience builder taking mixed static/dynamic lists; the result handle
  // type is taken from `target`.
  let builders = [
    OpBuilder<(ins "Value":$target,
                   "ArrayRef<OpFoldResult>":$mixedSgLayout,
                   "ArrayRef<OpFoldResult>":$mixedSgData,
                   "ArrayRef<OpFoldResult>":$mixedInstData
                   )>,
  ];

  let assemblyFormat = [{
    $target
    `sg_layout` `=` custom<DynamicIndexList>($sg_layout, $static_sg_layout)
    `sg_data` `=` custom<DynamicIndexList>($sg_data, $static_sg_data)
    `inst_data` `=` custom<DynamicIndexList>($inst_data, $static_inst_data)
    attr-dict `:` functional-type(operands, results)
  }];

  let extraClassDeclaration = [{
    ::mlir::DiagnosedSilenceableFailure apply(
        ::mlir::transform::TransformRewriter &rewriter,
        ::mlir::transform::TransformResults &transformResults,
        ::mlir::transform::TransformState &state);

    // Accessors that merge each static attribute with its dynamic operands
    // into a single list of OpFoldResults.
    ::llvm::SmallVector<::mlir::OpFoldResult> getMixedSgLayout() {
      Builder b(getContext());
      return getMixedValues(getStaticSgLayout(), getSgLayout(), b);
    }
    ::llvm::SmallVector<::mlir::OpFoldResult> getMixedSgData() {
      Builder b(getContext());
      return getMixedValues(getStaticSgData(), getSgData(), b);
    }
    ::llvm::SmallVector<::mlir::OpFoldResult> getMixedInstData() {
      Builder b(getContext());
      return getMixedValues(getStaticInstData(), getInstData(), b);
    }
  }];
}
84+
85+
#endif // XEGPU_EXTENSION
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
add_subdirectory(IR)
22
add_subdirectory(Transforms)
33
add_subdirectory(Utils)
4+
# XeGPU transform dialect extension ops.
add_subdirectory(TransformOps)
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# Library implementing the XeGPU transform dialect extension ops.
add_mlir_dialect_library(MLIRXeGPUTransformOps
  XeGPUTransformOps.cpp

  # Public headers live under the MLIR include tree.
  ADDITIONAL_HEADER_DIRS
  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/XeGPU/TransformOps

  DEPENDS
  MLIRXeGPUTransformOpsIncGen

  # The Arith and GPU dialects are declared as generated dialects by the
  # extension, so link them explicitly instead of relying on transitive deps.
  LINK_LIBS PUBLIC
  MLIRXeGPUDialect
  MLIRXeGPUTransforms
  MLIRIR
  MLIRTransformDialect
  MLIRFuncDialect
  MLIRSCFDialect
  MLIRArithDialect
  MLIRGPUDialect
)
Lines changed: 250 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,250 @@
1+
//===- XeGPUTransformOps.cpp - Implementation of XeGPU transformation ops -===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h"
10+
#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
11+
#include "mlir/Dialect/Func/IR/FuncOps.h"
12+
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
13+
#include "mlir/Dialect/Linalg/IR/Linalg.h"
14+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
15+
#include "mlir/Dialect/SCF/IR/SCF.h"
16+
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
17+
#include "mlir/Dialect/SCF/Utils/Utils.h"
18+
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
19+
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
20+
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
21+
#include "mlir/Dialect/Transform/Utils/Utils.h"
22+
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
23+
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
24+
#include "mlir/IR/DialectRegistry.h"
25+
#include "mlir/IR/Operation.h"
26+
#include "mlir/Interfaces/SideEffectInterfaces.h"
27+
#include "mlir/Support/LLVM.h"
28+
#include "llvm/ADT/SmallVector.h"
29+
#include "llvm/ADT/StringRef.h"
30+
31+
#include <numeric>
32+
33+
#include "llvm/Support/Debug.h"
34+
#define DEBUG_TYPE "xegpu-transforms"
35+
36+
using namespace mlir;
37+
using namespace mlir::transform;
38+
39+
class XeGPUTransformDialectExtension
40+
: public transform::TransformDialectExtension<
41+
XeGPUTransformDialectExtension> {
42+
public:
43+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(XeGPUTransformDialectExtension)
44+
45+
using Base::Base;
46+
47+
void init();
48+
};
49+
50+
/// Sets up the extension: declares the SCF, Arith, GPU and XeGPU dialects as
/// generated dialects, then registers every transform op listed in the
/// tablegen-generated op list.
void XeGPUTransformDialectExtension::init() {
  // Generated dialects.
  declareGeneratedDialect<scf::SCFDialect>();
  declareGeneratedDialect<arith::ArithDialect>();
  declareGeneratedDialect<gpu::GPUDialect>();
  declareGeneratedDialect<xegpu::XeGPUDialect>();

  // The template argument list is expanded from the generated include.
  registerTransformOps<
#define GET_OP_LIST
#include <mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc>
      >();
}
61+
62+
#define GET_OP_CLASSES
63+
#include <mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc>
64+
65+
/// Adds the XeGPU transform dialect extension to `registry`.
void mlir::xegpu::registerTransformDialectExtension(
    mlir::DialectRegistry &registry) {
  registry.addExtensions<XeGPUTransformDialectExtension>();
}
68+
69+
/// Assuming that `ofr` is an index attr or a param of index type
70+
/// or a transform dialect handle mapped to exactly one op
71+
/// with one index result, get that value and cast it to int type.
72+
static DiagnosedSilenceableFailure convertMixedValuesToInt(
73+
transform::TransformState &state, TransformOpInterface transformOp,
74+
SmallVectorImpl<int32_t> &result, ArrayRef<OpFoldResult> ofrs) {
75+
for (OpFoldResult ofr : ofrs) {
76+
// Attribute case.
77+
if (auto attr = dyn_cast<Attribute>(ofr)) {
78+
if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
79+
result.push_back(intAttr.getInt());
80+
} else {
81+
return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
82+
}
83+
continue;
84+
}
85+
86+
// Transform param case.
87+
Value transformValue = cast<Value>(ofr);
88+
if (isa<TransformParamTypeInterface>(transformValue.getType())) {
89+
ArrayRef<Attribute> params = state.getParams(transformValue);
90+
if (params.size() != 1)
91+
return transformOp.emitDefiniteFailure()
92+
<< "requires exactly one parameter associated";
93+
result.push_back(
94+
cast<IntegerAttr>(params.front()).getValue().getSExtValue());
95+
continue;
96+
}
97+
98+
// Payload value case.
99+
auto payloadOps = state.getPayloadOps(transformValue);
100+
if (!llvm::hasSingleElement(payloadOps)) {
101+
DiagnosedSilenceableFailure diag =
102+
transformOp.emitSilenceableError()
103+
<< "handle must be mapped to exactly one payload op";
104+
diag.attachNote(transformValue.getLoc())
105+
<< "mapped to " << llvm::range_size(payloadOps) << " payload ops";
106+
return diag;
107+
}
108+
109+
Operation *op = *payloadOps.begin();
110+
if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
111+
DiagnosedSilenceableFailure diag =
112+
transformOp.emitSilenceableError()
113+
<< "payload op must have exactly 1 index result";
114+
diag.attachNote(op->getLoc())
115+
<< "has " << op->getNumResults() << " results";
116+
return diag;
117+
}
118+
119+
IntegerAttr intAttr;
120+
if (!matchPattern(op->getResult(0), m_Constant(&intAttr)))
121+
return transformOp.emitSilenceableError()
122+
<< "requires param or handle to be the result of a constant like "
123+
"op";
124+
125+
result.push_back(intAttr.getInt());
126+
}
127+
return DiagnosedSilenceableFailure::success();
128+
}
129+
130+
/// Create a layout attribute from the given parameters.
131+
xegpu::LayoutAttr createLayoutAttr(MLIRContext *ctx, ArrayRef<int32_t> sgLayout,
132+
ArrayRef<int32_t> sgData,
133+
std::optional<ArrayRef<int32_t>> instData) {
134+
return xegpu::LayoutAttr::get(
135+
ctx, DenseI32ArrayAttr::get(ctx, sgLayout),
136+
DenseI32ArrayAttr::get(ctx, sgData),
137+
instData ? DenseI32ArrayAttr::get(ctx, instData.value()) : nullptr,
138+
/*lane_layout=*/nullptr,
139+
/*lane_data=*/nullptr,
140+
/*order=*/nullptr);
141+
}
142+
143+
/// Replace xegpu.create_nd_desc op with a new one with the given layout.
144+
xegpu::CreateNdDescOp setDescLayout(transform::TransformRewriter &rewriter,
145+
xegpu::CreateNdDescOp descOp,
146+
xegpu::LayoutAttr layout) {
147+
auto oldTensorDesc = descOp.getResult();
148+
auto descShapedType = cast<ShapedType>(oldTensorDesc.getType());
149+
auto descType = xegpu::TensorDescType::get(
150+
descShapedType.getShape(), descShapedType.getElementType(),
151+
/*array_length=*/1,
152+
/*boundary_check=*/true,
153+
/*memory_space=*/xegpu::MemorySpace::Global,
154+
/*layout=*/layout);
155+
156+
rewriter.setInsertionPointAfter(descOp);
157+
if (descOp.getMixedOffsets().size() > 0) {
158+
auto newDescOp = rewriter.replaceOpWithNewOp<xegpu::CreateNdDescOp>(
159+
descOp, descType, descOp.getSource(), descOp.getMixedOffsets(),
160+
descOp.getMixedSizes(), descOp.getMixedStrides());
161+
return newDescOp;
162+
}
163+
auto newDescOp = rewriter.replaceOpWithNewOp<xegpu::CreateNdDescOp>(
164+
descOp, descType, descOp.getSource(), descOp.getMixedSizes(),
165+
descOp.getMixedStrides());
166+
return newDescOp;
167+
}
168+
169+
/// Builds the op from mixed static/dynamic lists. Every list is split into
/// its dynamic values and static integers before forwarding to the generated
/// builder; the result handle type is the type of `target`.
void transform::SetDescLayoutOp::build(OpBuilder &builder,
                                       OperationState &result, Value target,
                                       ArrayRef<OpFoldResult> mixedSgLayout,
                                       ArrayRef<OpFoldResult> mixedSgData,
                                       ArrayRef<OpFoldResult> mixedInstData) {
  // sg_layout
  SmallVector<Value> layoutOperands;
  SmallVector<int64_t> layoutConstants;
  dispatchIndexOpFoldResults(mixedSgLayout, layoutOperands, layoutConstants);

  // sg_data
  SmallVector<Value> dataOperands;
  SmallVector<int64_t> dataConstants;
  dispatchIndexOpFoldResults(mixedSgData, dataOperands, dataConstants);

  // inst_data
  SmallVector<Value> instOperands;
  SmallVector<int64_t> instConstants;
  dispatchIndexOpFoldResults(mixedInstData, instOperands, instConstants);

  build(builder, result, target.getType(),
        /*target=*/target,
        /*sg_layout=*/layoutOperands,
        /*sg_data=*/dataOperands,
        /*inst_data=*/instOperands,
        /*static_sg_layout=*/layoutConstants,
        /*static_sg_data=*/dataConstants,
        /*static_inst_data=*/instConstants);
}
188+
189+
/// Applies the transform: resolves the `sg_layout`, `sg_data` and `inst_data`
/// lists to integers, builds a layout attribute from them and replaces the
/// targeted descriptor-creation op with one whose result type carries that
/// layout. The result handle is mapped to the new op.
///
/// Fails definitely if the target handle is not mapped to exactly one payload
/// op, and silenceably if that op is not an `xegpu::CreateNdDescOp`. Failures
/// from resolving the lists are propagated unchanged.
DiagnosedSilenceableFailure
transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter,
                                  transform::TransformResults &results,
                                  transform::TransformState &state) {

  auto targetOps = state.getPayloadOps(getTarget());
  if (!llvm::hasSingleElement(targetOps)) {
    return emitDefiniteFailure() << "requires exactly one targetOp handle (got "
                                 << llvm::range_size(targetOps) << ")";
  }
  Operation *target = *targetOps.begin();

  auto transformOp = cast<TransformOpInterface>(getOperation());

  SmallVector<int32_t> sgLayout;
  DiagnosedSilenceableFailure status =
      convertMixedValuesToInt(state, transformOp, sgLayout, getMixedSgLayout());
  if (!status.succeeded())
    return status;

  SmallVector<int32_t> sgData;
  status =
      convertMixedValuesToInt(state, transformOp, sgData, getMixedSgData());
  if (!status.succeeded())
    return status;

  SmallVector<int32_t> instData;
  status =
      convertMixedValuesToInt(state, transformOp, instData, getMixedInstData());
  if (!status.succeeded())
    return status;

  // For now only create_nd_desc op is supported.
  auto descOp = dyn_cast<xegpu::CreateNdDescOp>(target);
  if (!descOp) {
    auto diag = emitSilenceableFailure(getLoc())
                << "Expected a xegpu.create_nd_desc op, but got: "
                << target->getName();
    diag.attachNote(target->getLoc()) << "target op";
    return diag;
  }

  // An empty `inst_data` list means "not specified": leave the field unset
  // in the layout instead of attaching an empty array to it.
  std::optional<ArrayRef<int32_t>> maybeInstData;
  if (!instData.empty())
    maybeInstData = ArrayRef<int32_t>(instData);

  // Set layout attr in desc op's return type. Replaces old desc op.
  auto layoutAttr =
      createLayoutAttr(rewriter.getContext(), sgLayout, sgData, maybeInstData);
  auto newdescOp = setDescLayout(rewriter, descOp, layoutAttr);

  // Map result handles.
  results.set(cast<OpResult>(getTransformed()), {newdescOp.getOperation()});

  return DiagnosedSilenceableFailure::success();
}
241+
242+
/// Declares the op's effects on transform handles and on the payload IR.
void transform::SetDescLayoutOp::getEffects(
    ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  // The target handle is consumed.
  consumesHandle(getTargetMutable(), effects);

  // The dynamic layout operands are only read.
  onlyReadsHandle(getSgLayoutMutable(), effects);
  onlyReadsHandle(getSgDataMutable(), effects);
  onlyReadsHandle(getInstDataMutable(), effects);

  // A new handle is produced for the results, and the payload IR is modified.
  producesHandle(getOperation()->getOpResults(), effects);
  modifiesPayload(effects);
}

0 commit comments

Comments
 (0)