Skip to content

Commit 0c82c93

Browse files
committed
address Adam's comments
1 parent 1bbe829 commit 0c82c93

File tree

7 files changed

+96
-42
lines changed

7 files changed

+96
-42
lines changed

mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
#include "mlir/Dialect/Utils/StaticValueUtils.h"
1616

1717
#define GET_OP_CLASSES
18-
#include <mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h.inc>
18+
#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h.inc"
1919

2020
namespace mlir {
2121
class DialectRegistry;

mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9-
#ifndef XEGPU_EXTENSION
10-
#define XEGPU_EXTENSION
9+
#ifndef XEGPU_TRANSFORM_OPS
10+
#define XEGPU_TRANSFORM_OPS
1111

1212
include "mlir/Dialect/Transform/IR/TransformAttrs.td"
1313
include "mlir/Dialect/Transform/IR/TransformDialect.td"
@@ -27,11 +27,12 @@ def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [
2727
TransformOpInterface
2828
]> {
2929

30-
let summary = "Set xegpu.layout attribute to an xegpu op result.";
30+
let summary = "Set xegpu.layout attribute to a xegpu.create_nd_desc op result.";
3131
let description = [{
3232
Given an `xegpu.create_nd_desc` operation, this transform adds `xegpu.layout`
3333
attribute to the result tensor descriptor. The layout is defined by the
34-
`sg_layout`, `sg_data` and `inst_data` attributes. Returns a handle to the transformed op.
34+
`sg_layout`, and optional `sg_data` and `inst_data` attributes. Returns a handle
35+
to the transformed op.
3536
}];
3637

3738
let arguments = (ins
@@ -56,8 +57,8 @@ def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [
5657
let assemblyFormat = [{
5758
$target
5859
`sg_layout` `=` custom<DynamicIndexList>($sg_layout, $static_sg_layout)
59-
`sg_data` `=` custom<DynamicIndexList>($sg_data, $static_sg_data)
60-
`inst_data` `=` custom<DynamicIndexList>($inst_data, $static_inst_data)
60+
(`sg_data` `=` custom<DynamicIndexList>($sg_data, $static_sg_data)^)?
61+
(`inst_data` `=` custom<DynamicIndexList>($inst_data, $static_inst_data)^)?
6162
attr-dict `:` functional-type(operands, results)
6263
}];
6364

@@ -82,4 +83,4 @@ def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [
8283
}];
8384
}
8485

85-
#endif // XEGPU_EXTENSION
86+
#endif // XEGPU_TRANSFORM_OPS

mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp

Lines changed: 9 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,26 +7,10 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h"
10-
#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
11-
#include "mlir/Dialect/Func/IR/FuncOps.h"
1210
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
13-
#include "mlir/Dialect/Linalg/IR/Linalg.h"
14-
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1511
#include "mlir/Dialect/SCF/IR/SCF.h"
16-
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
17-
#include "mlir/Dialect/SCF/Utils/Utils.h"
18-
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
19-
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
20-
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
21-
#include "mlir/Dialect/Transform/Utils/Utils.h"
2212
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
2313
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
24-
#include "mlir/IR/DialectRegistry.h"
25-
#include "mlir/IR/Operation.h"
26-
#include "mlir/Interfaces/SideEffectInterfaces.h"
27-
#include "mlir/Support/LLVM.h"
28-
#include "llvm/ADT/SmallVector.h"
29-
#include "llvm/ADT/StringRef.h"
3014

3115
#include <numeric>
3216

@@ -129,11 +113,11 @@ static DiagnosedSilenceableFailure convertMixedValuesToInt(
129113

130114
/// Create a layout attribute from the given parameters.
131115
xegpu::LayoutAttr createLayoutAttr(MLIRContext *ctx, ArrayRef<int32_t> sgLayout,
132-
ArrayRef<int32_t> sgData,
116+
std::optional<ArrayRef<int32_t>> sgData,
133117
std::optional<ArrayRef<int32_t>> instData) {
134118
return xegpu::LayoutAttr::get(
135119
ctx, DenseI32ArrayAttr::get(ctx, sgLayout),
136-
DenseI32ArrayAttr::get(ctx, sgData),
120+
sgData ? DenseI32ArrayAttr::get(ctx, sgData.value()) : nullptr,
137121
instData ? DenseI32ArrayAttr::get(ctx, instData.value()) : nullptr,
138122
/*lane_layout=*/nullptr,
139123
/*lane_data=*/nullptr,
@@ -211,12 +195,17 @@ transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter,
211195
convertMixedValuesToInt(state, transformOp, sgData, getMixedSgData());
212196
if (!status.succeeded())
213197
return status;
198+
auto maybeSgData =
199+
sgData.empty() ? std::nullopt : std::optional<ArrayRef<int32_t>>(sgData);
214200

215201
SmallVector<int32_t> instData;
216202
status =
217203
convertMixedValuesToInt(state, transformOp, instData, getMixedInstData());
218204
if (!status.succeeded())
219205
return status;
206+
auto maybeInstData = instData.empty()
207+
? std::nullopt
208+
: std::optional<ArrayRef<int32_t>>(instData);
220209

221210
// For now only create_nd_desc op is supported.
222211
auto descOp = dyn_cast<xegpu::CreateNdDescOp>(target);
@@ -229,8 +218,8 @@ transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter,
229218
}
230219

231220
// Set layout attr in desc op's return type. Replaces old desc op.
232-
auto layoutAttr =
233-
createLayoutAttr(rewriter.getContext(), sgLayout, sgData, instData);
221+
auto layoutAttr = createLayoutAttr(rewriter.getContext(), sgLayout,
222+
maybeSgData, maybeInstData);
234223
auto newdescOp = setDescLayout(rewriter, descOp, layoutAttr);
235224

236225
// Map result handles.

mlir/python/mlir/dialects/transform/xegpu.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,15 @@ def __init__(
2828
self,
2929
target: Union[Operation, Value],
3030
sg_layout: MixedValues,
31-
sg_data: MixedValues,
32-
inst_data: MixedValues,
3331
*,
32+
sg_data: MixedValues = None,
33+
inst_data: MixedValues = None,
3434
loc=None,
3535
ip=None,
3636
):
3737
target_value = _get_op_result_or_value(target)
38+
sg_data = [] if sg_data is None else sg_data
39+
inst_data = [] if inst_data is None else inst_data
3840
(
3941
dynamic_sg_layout,
4042
static_sg_layout,
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
// RUN: mlir-opt %s -transform-interpreter -split-input-file -verify-diagnostics
2+
3+
func.func @set_desc_layout(%arg0: memref<4096x4096xf16>) {
4+
%c32 = arith.constant 32 : index // expected-note {{target op}}
5+
return
6+
}
7+
8+
module attributes {transform.with_named_sequence} {
9+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
10+
%0 = transform.structured.match ops{["arith.constant"]} in %arg1 : (!transform.any_op) -> !transform.any_op
11+
// expected-error@below {{Expected a xegpu.create_nd_desc op, but got: arith.constant}}
12+
%1 = transform.xegpu.set_desc_layout %0 sg_layout = [8, 4] : (!transform.any_op) -> !transform.any_op
13+
transform.yield
14+
}
15+
}

mlir/test/Dialect/XeGPU/transform-ops.mlir

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,44 @@ module attributes {transform.with_named_sequence} {
1919

2020
// -----
2121

22+
// CHECK-LABEL: @set_desc_layout_minimal
23+
func.func @set_desc_layout_minimal(%arg0: memref<4096x4096xf16>) {
24+
// CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
25+
// CHECK-SAME: #xegpu.layout<sg_layout = [8, 4]
26+
%0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
27+
return
28+
}
29+
30+
module attributes {transform.with_named_sequence} {
31+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
32+
%0 = transform.structured.match ops{["xegpu.create_nd_tdesc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
33+
// CHECK: transform.xegpu.set_desc_layout %{{.*}}
34+
%1 = transform.xegpu.set_desc_layout %0 sg_layout = [8, 4] : (!transform.any_op) -> !transform.any_op
35+
transform.yield
36+
}
37+
}
38+
39+
// -----
40+
41+
// CHECK-LABEL: @set_desc_layout_sg_data
42+
func.func @set_desc_layout_sg_data(%arg0: memref<4096x4096xf16>) {
43+
// CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
44+
// CHECK-SAME: #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>>
45+
%0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
46+
return
47+
}
48+
49+
module attributes {transform.with_named_sequence} {
50+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
51+
%0 = transform.structured.match ops{["xegpu.create_nd_tdesc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
52+
// CHECK: transform.xegpu.set_desc_layout %{{.*}}
53+
%1 = transform.xegpu.set_desc_layout %0 sg_layout = [8, 4] sg_data = [32, 32] : (!transform.any_op) -> !transform.any_op
54+
transform.yield
55+
}
56+
}
57+
58+
// -----
59+
2260
// CHECK-LABEL: @set_desc_layout_param
2361
func.func @set_desc_layout_param(%arg0: memref<4096x4096xf16>) {
2462
// CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0

mlir/test/python/dialects/transform_xegpu_ext.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,38 +17,47 @@ def run(f):
1717

1818

1919
@run
20-
def setDescLayout():
20+
def setDescLayoutMinimal():
2121
sequence = transform.SequenceOp(
2222
transform.FailurePropagationMode.Propagate,
2323
[],
2424
transform.OperationType.get("xegpu.create_nd_tdesc"),
2525
)
2626
with InsertionPoint(sequence.body):
27-
xegpu.SetDescLayoutOp(
28-
sequence.bodyTarget, sg_layout=[6, 4], sg_data=[32, 16], inst_data=[8, 16]
29-
)
27+
xegpu.SetDescLayoutOp(sequence.bodyTarget, sg_layout=[6, 4])
3028
transform.YieldOp()
31-
# CHECK-LABEL: TEST: setDescLayout
29+
# CHECK-LABEL: TEST: setDescLayoutMinimal
3230
# CHECK: %0 = transform.xegpu.set_desc_layout %
3331
# CHECK: sg_layout = [6, 4]
34-
# CHECK: sg_data = [32, 16]
35-
# CHECK: inst_data = [8, 16]
3632

3733

3834
@run
39-
def setDescLayoutDefaultIndex():
35+
def setDescLayoutSgData():
4036
sequence = transform.SequenceOp(
4137
transform.FailurePropagationMode.Propagate,
4238
[],
4339
transform.OperationType.get("xegpu.create_nd_tdesc"),
4440
)
4541
with InsertionPoint(sequence.body):
46-
xegpu.SetDescLayoutOp(
47-
sequence.bodyTarget, sg_layout=[6, 4], sg_data=[32, 16], inst_data=[8, 16]
48-
)
42+
xegpu.SetDescLayoutOp(sequence.bodyTarget, sg_layout=[6, 4], sg_data=[32, 16])
4943
transform.YieldOp()
50-
# CHECK-LABEL: TEST: setDescLayoutDefaultIndex
44+
# CHECK-LABEL: TEST: setDescLayoutSgData
5145
# CHECK: %0 = transform.xegpu.set_desc_layout %
5246
# CHECK: sg_layout = [6, 4]
5347
# CHECK: sg_data = [32, 16]
48+
49+
50+
@run
51+
def setDescLayoutInstData():
52+
sequence = transform.SequenceOp(
53+
transform.FailurePropagationMode.Propagate,
54+
[],
55+
transform.OperationType.get("xegpu.create_nd_tdesc"),
56+
)
57+
with InsertionPoint(sequence.body):
58+
xegpu.SetDescLayoutOp(sequence.bodyTarget, sg_layout=[6, 4], inst_data=[8, 16])
59+
transform.YieldOp()
60+
# CHECK-LABEL: TEST: setDescLayoutInstData
61+
# CHECK: %0 = transform.xegpu.set_desc_layout %
62+
# CHECK: sg_layout = [6, 4]
5463
# CHECK: inst_data = [8, 16]

0 commit comments

Comments
 (0)