Skip to content

Commit 5fb499a

Browse files
committed
[mlir][xegpu][transformops] add set_gpu_launch_threads op
1 parent bda7289 commit 5fb499a

File tree

6 files changed

+251
-0
lines changed

6 files changed

+251
-0
lines changed

mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,4 +78,41 @@ def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [
7878
}];
7979
}
8080

81+
// Transform op that overwrites the thread (block-size) dimensions of a
// `gpu.launch` payload op. Takes one payload handle ($target) plus a mixed
// static/dynamic list of thread sizes; produces no result handles.
def SetGPULaunchThreadsOp
    : Op<Transform_Dialect, "xegpu.set_gpu_launch_threads", [
      DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
      TransformOpInterface
  ]> {

  let summary = "Set number of threads for a given gpu.launch operation";
  let description = "Set number of threads for a given `gpu.launch` operation.";

  // $threads holds the dynamic (param/handle) thread sizes; $static_threads
  // holds the statically-known ones, interleaved DynamicIndexList-style.
  let arguments = (ins TransformHandleTypeInterface : $target,
                   Variadic<TransformAnyParamTypeOrAnyHandle> : $threads,
                   DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_threads
  );
  let results = (outs);
  let builders = [
    // Convenience builder accepting the thread sizes as OpFoldResults.
    OpBuilder<(ins "Value":$target, "ArrayRef<OpFoldResult>":$mixedThreads)>,
  ];

  let assemblyFormat = [{
    $target
    `threads` `=` custom<DynamicIndexList>($threads, $static_threads)
    attr-dict `:` qualified(type(operands))
  }];

  let extraClassDeclaration = [{
    // TransformOpInterface entry point; defined in XeGPUTransformOps.cpp.
    ::mlir::DiagnosedSilenceableFailure apply(
        ::mlir::transform::TransformRewriter &rewriter,
        ::mlir::transform::TransformResults &transformResults,
        ::mlir::transform::TransformState &state);

    // Recombines $static_threads and $threads into a single mixed list.
    ::llvm::SmallVector<::mlir::OpFoldResult> getMixedThreads() {
      Builder b(getContext());
      return getMixedValues(getStaticThreads(), getThreads(), b);
    }
  }];
}
117+
81118
#endif // XEGPU_TRANSFORM_OPS

mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h"
10+
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
1011
#include "mlir/Dialect/SCF/IR/SCF.h"
1112
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
1213
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
@@ -193,6 +194,69 @@ void transform::SetDescLayoutOp::getEffects(
193194
modifiesPayload(effects);
194195
}
195196

197+
/// Convenience builder: splits `mixedThreads` into its dynamic (Value) and
/// static (int64_t) components and forwards them to the generated builder.
void transform::SetGPULaunchThreadsOp::build(
    OpBuilder &builder, OperationState &ostate, Value target,
    ArrayRef<OpFoldResult> mixedThreads) {
  SmallVector<Value> dynamicVals;
  SmallVector<int64_t> staticVals;
  dispatchIndexOpFoldResults(mixedThreads, dynamicVals, staticVals);
  // The op produces no results, so only the operands/attributes are set.
  build(builder, ostate, target.getType(), /*target=*/target,
        /*threads=*/dynamicVals, /*static_threads=*/staticVals);
}
208+
209+
/// Applies the transform: replaces the three block-size (threads) operands of
/// the single `gpu.launch` op mapped to `$target` with newly created
/// `arith.constant` index values.
DiagnosedSilenceableFailure
transform::SetGPULaunchThreadsOp::apply(transform::TransformRewriter &rewriter,
                                        transform::TransformResults &results,
                                        transform::TransformState &state) {

  // The target handle must map to exactly one payload op; anything else is a
  // definite (non-recoverable) failure.
  auto targetOps = state.getPayloadOps(getTarget());
  if (!llvm::hasSingleElement(targetOps)) {
    return emitDefiniteFailure() << "Requires exactly one targetOp handle (got "
                                 << llvm::range_size(targetOps) << ")";
  }
  Operation *target = *targetOps.begin();

  // The payload op must be a gpu.launch; otherwise emit a silenceable failure
  // with a note attached at the offending op's location.
  auto launchOp = dyn_cast<gpu::LaunchOp>(target);
  if (!launchOp) {
    auto diag = emitSilenceableFailure(getLoc())
                << "Expected a gpu.launch op, but got: " << target->getName();
    diag.attachNote(target->getLoc()) << "target op";
    return diag;
  }

  // Resolve the mixed static/dynamic thread sizes to concrete integers.
  SmallVector<int32_t> threads;
  DiagnosedSilenceableFailure status =
      convertMixedValuesToInt(state, (*this), threads, getMixedThreads());
  if (!status.succeeded())
    return status;

  // Exactly three sizes are required: one per x/y/z thread dimension.
  if (threads.size() != 3) {
    return emitSilenceableFailure(getLoc())
           << "Expected threads to be a 3D vector";
  }

  // Materialize the new sizes as index constants just before the launch op.
  rewriter.setInsertionPoint(launchOp);
  auto createConstValue = [&](int value) {
    return arith::ConstantIndexOp::create(rewriter, launchOp.getLoc(), value);
  };

  // Replace threads in-place.
  launchOp.getBlockSizeXMutable().assign(createConstValue(threads[0]));
  launchOp.getBlockSizeYMutable().assign(createConstValue(threads[1]));
  launchOp.getBlockSizeZMutable().assign(createConstValue(threads[2]));

  return DiagnosedSilenceableFailure::success();
}
252+
253+
/// Memory effects: both operand handles are only read (not consumed), while
/// the payload IR is modified (the launch op's thread operands are replaced).
void transform::SetGPULaunchThreadsOp::getEffects(
    ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  onlyReadsHandle(getTargetMutable(), effects);
  onlyReadsHandle(getThreadsMutable(), effects);
  modifiesPayload(effects);
}
259+
196260
namespace {
197261
class XeGPUTransformDialectExtension
198262
: public transform::TransformDialectExtension<

mlir/python/mlir/dialects/transform/xegpu.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,3 +64,30 @@ def __init__(
6464
loc=loc,
6565
ip=ip,
6666
)
67+
68+
69+
@_ods_cext.register_operation(_Dialect, replace=True)
class SetGPULaunchThreadsOp(SetGPULaunchThreadsOp):
    """Specialization of the generated SetGPULaunchThreadsOp builder that
    accepts thread sizes as a mixed static/dynamic list."""

    def __init__(
        self,
        launch_op: Union[Operation, Value],
        threads: MixedValues,
        *,
        loc=None,
        ip=None,
    ):
        # Split mixed sizes into SSA values and static integers; the third
        # element of the dispatch result is not needed here.
        dynamic_threads, static_threads, _ = _dispatch_dynamic_index_list(threads)
        super().__init__(
            _get_op_result_or_value(launch_op),
            dynamic_threads,
            static_threads=static_threads,
            loc=loc,
            ip=ip,
        )

mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,56 @@ module attributes {transform.with_named_sequence} {
1313
transform.yield
1414
}
1515
}
16+
17+
// -----
18+
19+
// Negative test: the target handle maps to a payload op that is not a
// gpu.launch, which must produce a silenceable failure.
func.func @set_gpu_launch_threads_bad_handle(%arg0: memref<4096x4096xf16>) {
  %c32 = arith.constant 32 : index // expected-note {{target op}}
  return
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["arith.constant"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    // expected-error@below {{Expected a gpu.launch op, but got: arith.constant}}
    transform.xegpu.set_gpu_launch_threads %0 threads = [8, 4, 1] : !transform.any_op
    transform.yield
  }
}
32+
33+
// -----
34+
35+
// Negative test: the target handle maps to two payload ops, while the
// transform requires exactly one (a definite failure).
func.func @set_gpu_launch_threads_many_handles(%arg0: memref<4096x4096xf16>) {
  %c32 = arith.constant 32 : index
  %c64 = arith.constant 64 : index
  return
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["arith.constant"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    // expected-error@below {{Requires exactly one targetOp handle (got 2)}}
    transform.xegpu.set_gpu_launch_threads %0 threads = [8, 4, 1] : !transform.any_op
    transform.yield
  }
}
49+
50+
// -----
51+
52+
// Negative test: the threads list must contain exactly three entries
// (x, y, z); two entries must be rejected.
func.func @set_gpu_launch_threads_bad_threads(%arg0: memref<4096x4096xf16>) {
  %c1 = arith.constant 1 : index
  %c16 = arith.constant 16 : index
  gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg9 = %c16, %arg10 = %c16, %arg11 = %c1) threads(%arg6, %arg7, %arg8) in (%arg12 = %c1, %arg13 = %c1, %arg14 = %c1) {
    gpu.terminator
  } {SCFToGPU_visited}
  return
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["gpu.launch"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    // expected-error@below {{Expected threads to be a 3D vector}}
    transform.xegpu.set_gpu_launch_threads %0 threads = [8, 4] : !transform.any_op
    transform.yield
  }
}

mlir/test/Dialect/XeGPU/transform-ops.mlir

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,3 +56,58 @@ module attributes {transform.with_named_sequence} {
5656
transform.yield
5757
}
5858
}
59+
60+
// -----
61+
62+
// Positive test: static thread sizes [8, 4, 1] replace the launch's original
// (%c1, %c1, %c1) thread operands with fresh index constants.
// CHECK-LABEL: @set_gpu_launch_threads
func.func @set_gpu_launch_threads(%arg0: memref<4096x4096xf16>) {
  // CHECK: %[[C1:.+]] = arith.constant 1 : index
  %c1 = arith.constant 1 : index
  // CHECK: %[[C16:.+]] = arith.constant 16 : index
  %c16 = arith.constant 16 : index
  // CHECK: %[[C8:.+]] = arith.constant 8 : index
  // CHECK: %[[C4:.+]] = arith.constant 4 : index
  // CHECK: %[[C1_0:.+]] = arith.constant 1 : index
  // CHECK: gpu.launch blocks(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %[[C16]], %{{.*}} = %[[C16]], %{{.*}} = %[[C1]])
  // CHECK-SAME: threads(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %[[C8]], %{{.*}} = %[[C4]], %{{.*}} = %[[C1_0]])
  gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg9 = %c16, %arg10 = %c16, %arg11 = %c1) threads(%arg6, %arg7, %arg8) in (%arg12 = %c1, %arg13 = %c1, %arg14 = %c1) {
    gpu.terminator
  } {SCFToGPU_visited}
  return
}
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["gpu.launch"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    // CHECK: transform.xegpu.set_gpu_launch_threads %{{.*}}
    transform.xegpu.set_gpu_launch_threads %0 threads = [8, 4, 1] : !transform.any_op
    transform.yield
  }
}
}
86+
87+
// -----
88+
89+
// Positive test: one of the thread sizes is supplied as a transform param
// (!transform.param<i64>) rather than a static integer.
// CHECK-LABEL: @set_gpu_launch_threads_param
func.func @set_gpu_launch_threads_param(%arg0: memref<4096x4096xf16>) {
  // CHECK: %[[C1:.+]] = arith.constant 1 : index
  %c1 = arith.constant 1 : index
  // CHECK: %[[C16:.+]] = arith.constant 16 : index
  %c16 = arith.constant 16 : index
  // CHECK: %[[C8:.+]] = arith.constant 8 : index
  // CHECK: %[[C4:.+]] = arith.constant 4 : index
  // CHECK: %[[C1_0:.+]] = arith.constant 1 : index
  // CHECK: gpu.launch blocks(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %[[C16]], %{{.*}} = %[[C16]], %{{.*}} = %[[C1]])
  // CHECK-SAME: threads(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %[[C8]], %{{.*}} = %[[C4]], %{{.*}} = %[[C1_0]])
  gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg9 = %c16, %arg10 = %c16, %arg11 = %c1) threads(%arg6, %arg7, %arg8) in (%arg12 = %c1, %arg13 = %c1, %arg14 = %c1) {
    gpu.terminator
  } {SCFToGPU_visited}
  return
}
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["gpu.launch"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    // CHECK: transform.xegpu.set_gpu_launch_threads %{{.*}}
    %th1 = transform.param.constant 4 : i64 -> !transform.param<i64>
    transform.xegpu.set_gpu_launch_threads %0 threads = [8, %th1, 1] : !transform.any_op, !transform.param<i64>
    transform.yield
  }
}

mlir/test/python/dialects/transform_xegpu_ext.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,3 +49,18 @@ def setDescLayoutInstData():
4949
# CHECK: sg_layout = [6, 4]
5050
# CHECK: sg_data = [32, 16]
5151
# CHECK: inst_data = [8, 16]
52+
53+
54+
@run
def setGPULaunchThreadsOp():
    # Build a transform sequence rooted at a gpu.launch handle and attach a
    # xegpu.set_gpu_launch_threads op with static thread sizes [8, 4, 1].
    sequence = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate,
        [],
        # Bug fix: op name was misspelled as "gpu.lauch"; the handle type
        # must name the real gpu.launch operation.
        transform.OperationType.get("gpu.launch"),
    )
    with InsertionPoint(sequence.body):
        xegpu.SetGPULaunchThreadsOp(sequence.bodyTarget, threads=[8, 4, 1])
        transform.YieldOp()
    # CHECK-LABEL: TEST: setGPULaunchThreadsOp
    # CHECK: transform.xegpu.set_gpu_launch_threads
    # CHECK: threads = [8, 4, 1]

0 commit comments

Comments
 (0)