Skip to content

Commit a759aa6

Browse files
committed
address review comments
1 parent 8897891 commit a759aa6

File tree

3 files changed

+63
-26
lines changed

3 files changed

+63
-26
lines changed

mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,13 @@ def GetDescOp : Op<Transform_Dialect, "xegpu.get_desc_op", [
2424
let summary = "Get a handle to the descriptor op of a value.";
2525
let description = [{
2626
Traces the producers of the given value until an `xegpu.create_nd_tdesc`
27-
descriptor op is found. Returns a handle to it.
27+
descriptor op is found. Returns a handle to it. Currently traces
28+
producers by following only the first operand of producer ops.
2829
}];
2930

30-
let arguments = (ins TransformValueHandleTypeInterface : $target);
31+
let arguments = (ins TransformValueHandleTypeInterface:$target);
3132

32-
let results = (outs TransformHandleTypeInterface : $descHandle);
33+
let results = (outs TransformHandleTypeInterface:$descHandle);
3334
let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
3435
}
3536

@@ -48,16 +49,16 @@ def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [
4849
}];
4950

5051
let arguments = (ins
51-
TransformHandleTypeInterface : $target,
52-
Variadic<TransformAnyParamTypeOrAnyHandle> : $sg_layout,
53-
Variadic<TransformAnyParamTypeOrAnyHandle> : $sg_data,
54-
Variadic<TransformAnyParamTypeOrAnyHandle> : $inst_data,
52+
TransformHandleTypeInterface:$target,
53+
Variadic<TransformAnyParamTypeOrAnyHandle>:$sg_layout,
54+
Variadic<TransformAnyParamTypeOrAnyHandle>:$sg_data,
55+
Variadic<TransformAnyParamTypeOrAnyHandle>:$inst_data,
5556
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_layout,
5657
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_data,
5758
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_inst_data
5859
);
5960

60-
let results = (outs TransformHandleTypeInterface : $transformed);
61+
let results = (outs TransformHandleTypeInterface:$transformed);
6162
let builders = [
6263
OpBuilder<(ins "Value":$target,
6364
"ArrayRef<OpFoldResult>":$mixedSgLayout,

mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
#include <optional>
1515

16-
#include "llvm/Support/Debug.h"
16+
#include "llvm/Support/DebugLog.h"
1717
#define DEBUG_TYPE "xegpu-transforms"
1818

1919
using namespace mlir;
@@ -88,14 +88,13 @@ static std::optional<T> findProducerOfType(Value val) {
8888
if (!currentValue.getDefiningOp()) {
8989
// Value may be a block argument initialized outside a loop.
9090
if (val.getNumUses() == 0) {
91-
LLVM_DEBUG(llvm::dbgs()
92-
<< "Failed to find producer op, value has no uses.");
91+
LDBG() << "Failed to find producer op, value has no uses.";
9392
return std::nullopt;
9493
}
9594
auto userOp = val.getUsers().begin();
9695
auto parentLoop = userOp->getParentOfType<LoopLikeOpInterface>();
9796
if (!parentLoop) {
98-
LLVM_DEBUG(llvm::dbgs() << "Failed to find producer op, not in a loop.");
97+
LDBG() << "Failed to find producer op, not in a loop.";
9998
return std::nullopt;
10099
}
101100
int64_t iterArgIdx;
@@ -104,8 +103,7 @@ static std::optional<T> findProducerOfType(Value val) {
104103
iterArgIdx = iterArg.getArgNumber() - numInductionVars;
105104
currentValue = parentLoop.getInits()[iterArgIdx];
106105
} else {
107-
LLVM_DEBUG(llvm::dbgs()
108-
<< "Failed to find producer op, value not in init values.");
106+
LDBG() << "Failed to find producer op, value not in init values.";
109107
return std::nullopt;
110108
}
111109
}
@@ -159,7 +157,6 @@ DiagnosedSilenceableFailure
159157
transform::GetDescOp::apply(transform::TransformRewriter &rewriter,
160158
transform::TransformResults &results,
161159
transform::TransformState &state) {
162-
163160
auto targetValues = state.getPayloadValues(getTarget());
164161
if (!llvm::hasSingleElement(targetValues)) {
165162
return emitDefiniteFailure()
@@ -170,7 +167,9 @@ transform::GetDescOp::apply(transform::TransformRewriter &rewriter,
170167
auto maybeDescOp =
171168
findProducerOfType<xegpu::CreateNdDescOp>(*targetValues.begin());
172169
if (!maybeDescOp) {
173-
return emitSilenceableFailure(getLoc()) << "Could not find descriptor op.";
170+
return emitSilenceableFailure(getLoc())
171+
<< "Could not find a matching descriptor op when walking the "
172+
"producer chain of the first operand.";
174173
}
175174

176175
results.set(llvm::cast<OpResult>(getResult()), {*maybeDescOp});

mlir/test/Dialect/XeGPU/transform-ops.mlir

Lines changed: 47 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,67 @@
11
// RUN: mlir-opt %s -transform-interpreter -split-input-file -verify-diagnostics | FileCheck %s
22

3-
// CHECK-LABEL: @get_desc_op
4-
func.func @get_desc_op(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
3+
// CHECK-LABEL: @get_desc_op_a
4+
func.func @get_desc_op_a(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
5+
%c32 = arith.constant 32 : index
6+
%c4096 = arith.constant 4096 : index
7+
%c0 = arith.constant 0 : index
8+
%0 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
9+
%1 = xegpu.load_nd %0[%c0, %c0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
10+
// expected-remark @below {{found desc op}}
11+
%3 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
12+
%4 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
13+
%2 = scf.for %arg3 = %c0 to %c4096 step %c32 iter_args(%arg4 = %1) -> (vector<256x256xf16>) {
14+
%5 = xegpu.load_nd %3[%c0, %arg3] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
15+
%6 = xegpu.load_nd %4[%arg3, %c0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
16+
%7 = xegpu.dpas %5, %6, %arg4 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
17+
scf.yield %7 : vector<256x256xf16>
18+
}
19+
return
20+
}
21+
22+
module attributes {transform.with_named_sequence} {
23+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
24+
%0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op
25+
%1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value
26+
%2 = transform.xegpu.get_desc_op %1 : (!transform.any_value) -> !transform.any_op
27+
transform.debug.emit_remark_at %2, "found desc op" : !transform.any_op
28+
transform.yield
29+
}
30+
}
31+
32+
// -----
33+
34+
// CHECK-LABEL: @get_desc_op_c
35+
func.func @get_desc_op_c(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
36+
%c32 = arith.constant 32 : index
37+
%c4096 = arith.constant 4096 : index
538
%c0 = arith.constant 0 : index
6-
%0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
7-
%1 = xegpu.load_nd %0[%c0, %c0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
839
// expected-remark @below {{found desc op}}
9-
%2 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
10-
%3 = xegpu.load_nd %2[%c0, %c0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
11-
%4 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
12-
%5 = xegpu.load_nd %4[%c0, %c0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
13-
%6 = xegpu.dpas %1, %3, %5 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
40+
%0 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
41+
%1 = xegpu.load_nd %0[%c0, %c0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
42+
%3 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
43+
%4 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
44+
%2 = scf.for %arg3 = %c0 to %c4096 step %c32 iter_args(%arg4 = %1) -> (vector<256x256xf16>) {
45+
%5 = xegpu.load_nd %3[%c0, %arg3] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
46+
%6 = xegpu.load_nd %4[%arg3, %c0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
47+
%7 = xegpu.dpas %5, %6, %arg4 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
48+
scf.yield %7 : vector<256x256xf16>
49+
}
1450
return
1551
}
1652

1753
module attributes {transform.with_named_sequence} {
1854
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
1955
%0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op
20-
%1 = transform.get_operand %0[1] : (!transform.any_op) -> !transform.any_value
56+
%1 = transform.get_operand %0[2] : (!transform.any_op) -> !transform.any_value
2157
%2 = transform.xegpu.get_desc_op %1 : (!transform.any_value) -> !transform.any_op
2258
transform.debug.emit_remark_at %2, "found desc op" : !transform.any_op
2359
transform.yield
2460
}
2561
}
2662

2763
// -----
64+
2865
// CHECK-LABEL: @set_desc_layout
2966
func.func @set_desc_layout(%arg0: memref<4096x4096xf16>) {
3067
// CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0

0 commit comments

Comments
 (0)