Skip to content

Commit 300750d

Browse files
authored
[MLIR][XeGPU][TransformOps] Add set_gpu_launch_threads op (llvm#166865)
Adds `transform.xegpu.set_gpu_launch_threads` that overrides `gpu.launch` operation threads.
1 parent b440fb7 commit 300750d

File tree

6 files changed

+263
-0
lines changed

6 files changed

+263
-0
lines changed

mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,4 +161,43 @@ def SetOpLayoutAttrOp : Op<Transform_Dialect, "xegpu.set_op_layout_attr", [
161161
}];
162162
}
163163

164+
// Transform op that rewrites the x,y,z thread-size operands of a gpu.launch
// payload operation. The thread sizes form a mixed static/dynamic triple.
def SetGPULaunchThreadsOp
    : Op<Transform_Dialect, "xegpu.set_gpu_launch_threads", [
      DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
      TransformOpInterface
    ]> {

  let summary = "Set number of threads for a given gpu.launch operation";
  let description = [{
    Overrides the x,y,z threads operands of a given `gpu.launch` operation in-place.
  }];

  // `threads` carries the dynamic thread sizes (params or handles);
  // `static_threads` carries the compile-time-constant ones. Together they
  // encode one mixed x,y,z triple (see getMixedThreads below).
  let arguments = (ins TransformHandleTypeInterface:$target,
                   Variadic<TransformAnyParamTypeOrAnyHandle>:$threads,
                   DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_threads
  );
  let results = (outs);
  let builders = [
    OpBuilder<(ins "Value":$target, "ArrayRef<OpFoldResult>":$mixedThreads)>,
  ];

  let assemblyFormat = [{
    $target
    `threads` `=` custom<DynamicIndexList>($threads, $static_threads)
    attr-dict `:` qualified(type(operands))
  }];

  let extraClassDeclaration = [{
    ::mlir::DiagnosedSilenceableFailure apply(
        ::mlir::transform::TransformRewriter &rewriter,
        ::mlir::transform::TransformResults &transformResults,
        ::mlir::transform::TransformState &state);

    // Recombines the static_threads attribute and the dynamic threads
    // operands into a single list of OpFoldResults.
    ::llvm::SmallVector<::mlir::OpFoldResult> getMixedThreads() {
      Builder b(getContext());
      return getMixedValues(getStaticThreads(), getThreads(), b);
    }
  }];
}
202+
164203
#endif // XEGPU_TRANSFORM_OPS

mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h"
10+
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
1011
#include "mlir/Dialect/SCF/IR/SCF.h"
1112
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
1213
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
@@ -341,6 +342,69 @@ void transform::SetOpLayoutAttrOp::getEffects(
341342
modifiesPayload(effects);
342343
}
343344

345+
/// Custom builder: splits a mixed list of thread sizes into the dynamic
/// operand list and the static attribute expected by the generated builder.
void transform::SetGPULaunchThreadsOp::build(
    OpBuilder &builder, OperationState &ostate, Value target,
    ArrayRef<OpFoldResult> mixedThreads) {
  SmallVector<Value> dynamicVals;
  SmallVector<int64_t> staticVals;
  // Constants go into `staticVals`, SSA values into `dynamicVals`.
  dispatchIndexOpFoldResults(mixedThreads, dynamicVals, staticVals);
  build(builder, ostate, target.getType(),
        /*target=*/target,
        /*threads=*/dynamicVals,
        /*static_threads=*/staticVals);
}
356+
357+
/// Applies the transform: resolves the requested x,y,z thread sizes and
/// rewrites the block-size operands of the single targeted gpu.launch op.
DiagnosedSilenceableFailure
transform::SetGPULaunchThreadsOp::apply(transform::TransformRewriter &rewriter,
                                        transform::TransformResults &results,
                                        transform::TransformState &state) {
  // The handle must map to exactly one payload operation.
  auto payloadOps = state.getPayloadOps(getTarget());
  if (!llvm::hasSingleElement(payloadOps))
    return emitDefiniteFailure() << "Requires exactly one targetOp handle (got "
                                 << llvm::range_size(payloadOps) << ")";
  Operation *payload = *payloadOps.begin();

  auto launch = dyn_cast<gpu::LaunchOp>(payload);
  if (!launch) {
    auto diag = emitSilenceableFailure(getLoc())
                << "Expected a gpu.launch op, but got: " << payload->getName();
    diag.attachNote(payload->getLoc()) << "target op";
    return diag;
  }

  // Resolve the mixed static/dynamic thread sizes to concrete integers.
  SmallVector<int32_t> threadDims;
  DiagnosedSilenceableFailure status =
      convertMixedValuesToInt(state, (*this), threadDims, getMixedThreads());
  if (!status.succeeded())
    return status;

  if (threadDims.size() != 3)
    return emitSilenceableFailure(getLoc())
           << "Expected threads argument to consist of three values (got "
           << threadDims.size() << ")";

  // Materialize index constants just before the launch and swap them into the
  // block-size operands in-place.
  rewriter.setInsertionPoint(launch);
  Location loc = launch.getLoc();
  auto makeIndexConst = [&](int32_t v) {
    return arith::ConstantIndexOp::create(rewriter, loc, v);
  };
  launch.getBlockSizeXMutable().assign(makeIndexConst(threadDims[0]));
  launch.getBlockSizeYMutable().assign(makeIndexConst(threadDims[1]));
  launch.getBlockSizeZMutable().assign(makeIndexConst(threadDims[2]));

  return DiagnosedSilenceableFailure::success();
}
400+
401+
/// Memory-effect declaration: all transform handles are read-only inputs;
/// only the payload IR (the gpu.launch op) is mutated.
void transform::SetGPULaunchThreadsOp::getEffects(
    ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  onlyReadsHandle(getTargetMutable(), effects);
  onlyReadsHandle(getThreadsMutable(), effects);
  modifiesPayload(effects);
}
407+
344408
namespace {
345409
class XeGPUTransformDialectExtension
346410
: public transform::TransformDialectExtension<

mlir/python/mlir/dialects/transform/xegpu.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,3 +132,39 @@ def __init__(
132132
loc=loc,
133133
ip=ip,
134134
)
135+
136+
137+
class SetGPULaunchThreadsOp(SetGPULaunchThreadsOp):
    """Specialization for SetGPULaunchThreadsOp class."""

    def __init__(
        self,
        launch_op: Union[Operation, Value],
        threads: MixedValues,
        *,
        loc=None,
        ip=None,
    ):
        # Split the mixed thread sizes into the dynamic operand list and the
        # static attribute the generated builder expects.
        dynamic_threads, static_threads, _ = _dispatch_dynamic_index_list(threads)
        super().__init__(
            _get_op_result_or_value(launch_op),
            dynamic_threads,
            static_threads=static_threads,
            loc=loc,
            ip=ip,
        )
161+
162+
163+
def set_gpu_launch_threads(
    launch_op: Union[Operation, Value],
    threads: MixedValues,
    *,
    loc=None,
    ip=None,
) -> SetGPULaunchThreadsOp:
    """Builds a `transform.xegpu.set_gpu_launch_threads` op for `launch_op`."""
    op = SetGPULaunchThreadsOp(launch_op, threads, loc=loc, ip=ip)
    return op

mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,3 +71,56 @@ module attributes {transform.with_named_sequence} {
7171
transform.yield
7272
}
7373
}
74+
75+
// -----
76+
77+
// Negative test: the matched payload op is not a gpu.launch.
func.func @set_gpu_launch_threads_bad_handle(%arg0: memref<4096x4096xf16>) {
  %c32 = arith.constant 32 : index // expected-note {{target op}}
  return
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["arith.constant"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    // expected-error@below {{Expected a gpu.launch op, but got: arith.constant}}
    transform.xegpu.set_gpu_launch_threads %0 threads = [8, 4, 1] : !transform.any_op
    transform.yield
  }
}

// -----

// Negative test: the handle maps to two payload ops instead of one.
func.func @set_gpu_launch_threads_many_handles(%arg0: memref<4096x4096xf16>) {
  %c32 = arith.constant 32 : index
  %c64 = arith.constant 64 : index
  return
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["arith.constant"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    // expected-error@below {{Requires exactly one targetOp handle (got 2)}}
    transform.xegpu.set_gpu_launch_threads %0 threads = [8, 4, 1] : !transform.any_op
    transform.yield
  }
}

// -----

// Negative test: only two thread sizes are supplied instead of three.
func.func @set_gpu_launch_threads_bad_threads(%arg0: memref<4096x4096xf16>) {
  %c1 = arith.constant 1 : index
  %c16 = arith.constant 16 : index
  gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg9 = %c16, %arg10 = %c16, %arg11 = %c1) threads(%arg6, %arg7, %arg8) in (%arg12 = %c1, %arg13 = %c1, %arg14 = %c1) {
    gpu.terminator
  }
  return
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["gpu.launch"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    // expected-error@below {{Expected threads argument to consist of three values (got 2)}}
    transform.xegpu.set_gpu_launch_threads %0 threads = [8, 4] : !transform.any_op
    transform.yield
  }
}

mlir/test/Dialect/XeGPU/transform-ops.mlir

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,7 @@ module attributes {transform.with_named_sequence} {
230230
transform.yield
231231
}
232232
}
233+
233234
// -----
234235

235236
// CHECK-LABEL: @set_op_layout_attr_operand1
@@ -252,3 +253,58 @@ module attributes {transform.with_named_sequence} {
252253
transform.yield
253254
}
254255
}
256+
257+
// -----
258+
259+
// Positive test: all three thread sizes are static constants.
// CHECK-LABEL: @set_gpu_launch_threads
func.func @set_gpu_launch_threads(%arg0: memref<4096x4096xf16>) {
  // CHECK: %[[C1:.+]] = arith.constant 1 : index
  %c1 = arith.constant 1 : index
  // CHECK: %[[C16:.+]] = arith.constant 16 : index
  %c16 = arith.constant 16 : index
  // CHECK: %[[C8:.+]] = arith.constant 8 : index
  // CHECK: %[[C4:.+]] = arith.constant 4 : index
  // CHECK: %[[C1_0:.+]] = arith.constant 1 : index
  // CHECK: gpu.launch blocks(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %[[C16]], %{{.*}} = %[[C16]], %{{.*}} = %[[C1]])
  // CHECK-SAME: threads(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %[[C8]], %{{.*}} = %[[C4]], %{{.*}} = %[[C1_0]])
  gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg9 = %c16, %arg10 = %c16, %arg11 = %c1) threads(%arg6, %arg7, %arg8) in (%arg12 = %c1, %arg13 = %c1, %arg14 = %c1) {
    gpu.terminator
  }
  return
}
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["gpu.launch"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    // CHECK: transform.xegpu.set_gpu_launch_threads %{{.*}}
    transform.xegpu.set_gpu_launch_threads %0 threads = [8, 4, 1] : !transform.any_op
    transform.yield
  }
}

// -----

// Positive test: the y thread size is supplied via a transform param.
// CHECK-LABEL: @set_gpu_launch_threads_param
func.func @set_gpu_launch_threads_param(%arg0: memref<4096x4096xf16>) {
  // CHECK: %[[C1:.+]] = arith.constant 1 : index
  %c1 = arith.constant 1 : index
  // CHECK: %[[C16:.+]] = arith.constant 16 : index
  %c16 = arith.constant 16 : index
  // CHECK: %[[C8:.+]] = arith.constant 8 : index
  // CHECK: %[[C4:.+]] = arith.constant 4 : index
  // CHECK: %[[C1_0:.+]] = arith.constant 1 : index
  // CHECK: gpu.launch blocks(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %[[C16]], %{{.*}} = %[[C16]], %{{.*}} = %[[C1]])
  // CHECK-SAME: threads(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %[[C8]], %{{.*}} = %[[C4]], %{{.*}} = %[[C1_0]])
  gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg9 = %c16, %arg10 = %c16, %arg11 = %c1) threads(%arg6, %arg7, %arg8) in (%arg12 = %c1, %arg13 = %c1, %arg14 = %c1) {
    gpu.terminator
  }
  return
}
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["gpu.launch"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    // CHECK: transform.xegpu.set_gpu_launch_threads %{{.*}}
    %th1 = transform.param.constant 4 : i64 -> !transform.param<i64>
    transform.xegpu.set_gpu_launch_threads %0 threads = [8, %th1, 1] : !transform.any_op, !transform.param<i64>
    transform.yield
  }
}

mlir/test/python/dialects/transform_xegpu_ext.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,3 +113,18 @@ def setOpLayoutAttrResult():
113113
# CHECK: sg_layout = [6, 4]
114114
# CHECK: sg_data = [32, 16]
115115
# CHECK: inst_data = [8, 16]
116+
117+
118+
@run
def setGPULaunchThreadsOp():
    """Builds a set_gpu_launch_threads op via the Python helper and checks
    that it round-trips with the expected mixed thread list."""
    sequence = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate,
        [],
        transform.OperationType.get("gpu.launch"),
    )
    with InsertionPoint(sequence.body):
        xegpu.set_gpu_launch_threads(sequence.bodyTarget, threads=[8, 4, 1])
        transform.YieldOp()
    # CHECK-LABEL: TEST: setGPULaunchThreadsOp
    # CHECK: transform.xegpu.set_gpu_launch_threads
    # CHECK: threads = [8, 4, 1]

0 commit comments

Comments
 (0)