[MLIR][XeGPU][TransformOps] Add set_gpu_launch_threads op #166865
Conversation
@llvm/pr-subscribers-mlir-gpu @llvm/pr-subscribers-mlir

Author: Tuomas Kärnä (tkarna)

Changes

Adds `transform.xegpu.set_gpu_launch_threads` that overrides the thread sizes of a `gpu.launch` operation. For reference, the rationale behind xegpu transform ops is outlined in this RFC document.

Full diff: https://github.com/llvm/llvm-project/pull/166865.diff

6 Files Affected:
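For orientation before the per-file diff, here is a minimal usage sketch adapted from the tests added in this PR; the payload function and value names are illustrative and not part of the patch:

```mlir
// Payload: a gpu.launch whose threads(...) sizes will be overridden to 8x4x1.
func.func @payload() {
  %c1 = arith.constant 1 : index
  %c16 = arith.constant 16 : index
  gpu.launch blocks(%bx, %by, %bz) in (%gx = %c16, %gy = %c16, %gz = %c1)
             threads(%tx, %ty, %tz) in (%sx = %c1, %sy = %c1, %sz = %c1) {
    gpu.terminator
  }
  return
}

// Transform script: match the launch op and rewrite its thread sizes in place.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
    %launch = transform.structured.match ops{["gpu.launch"]} in %root
      : (!transform.any_op) -> !transform.any_op
    transform.xegpu.set_gpu_launch_threads %launch threads = [8, 4, 1] : !transform.any_op
    transform.yield
  }
}
```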
diff --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
index b985d5450be0e..0b138d13d8aee 100644
--- a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
@@ -78,4 +78,41 @@ def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [
}];
}
+def SetGPULaunchThreadsOp
+ : Op<Transform_Dialect, "xegpu.set_gpu_launch_threads", [
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ TransformOpInterface
+ ]> {
+
+ let summary = "Set number of threads for a given gpu.launch operation";
+ let description = "Set number of threads for a given `gpu.launch` operation.";
+
+ let arguments = (ins TransformHandleTypeInterface : $target,
+ Variadic<TransformAnyParamTypeOrAnyHandle> : $threads,
+ DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_threads
+ );
+ let results = (outs);
+ let builders = [
+ OpBuilder<(ins "Value":$target, "ArrayRef<OpFoldResult>":$mixedThreads)>,
+ ];
+
+ let assemblyFormat = [{
+ $target
+ `threads` `=` custom<DynamicIndexList>($threads, $static_threads)
+ attr-dict `:` qualified(type(operands))
+ }];
+
+ let extraClassDeclaration = [{
+ ::mlir::DiagnosedSilenceableFailure apply(
+ ::mlir::transform::TransformRewriter &rewriter,
+ ::mlir::transform::TransformResults &transformResults,
+ ::mlir::transform::TransformState &state);
+
+ ::llvm::SmallVector<::mlir::OpFoldResult> getMixedThreads() {
+ Builder b(getContext());
+ return getMixedValues(getStaticThreads(), getThreads(), b);
+ }
+ }];
+}
+
#endif // XEGPU_TRANSFORM_OPS
diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
index 8943ba09d9c34..73c3776a75276 100644
--- a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
@@ -193,6 +194,69 @@ void transform::SetDescLayoutOp::getEffects(
modifiesPayload(effects);
}
+void transform::SetGPULaunchThreadsOp::build(
+ OpBuilder &builder, OperationState &ostate, Value target,
+ ArrayRef<OpFoldResult> mixedThreads) {
+ SmallVector<int64_t> staticThreads;
+ SmallVector<Value> dynamicThreads;
+ dispatchIndexOpFoldResults(mixedThreads, dynamicThreads, staticThreads);
+ build(builder, ostate, target.getType(),
+ /*target=*/target,
+ /*threads=*/dynamicThreads,
+ /*static_threads=*/staticThreads);
+}
+
+DiagnosedSilenceableFailure
+transform::SetGPULaunchThreadsOp::apply(transform::TransformRewriter &rewriter,
+ transform::TransformResults &results,
+ transform::TransformState &state) {
+
+ auto targetOps = state.getPayloadOps(getTarget());
+ if (!llvm::hasSingleElement(targetOps)) {
+ return emitDefiniteFailure() << "Requires exactly one targetOp handle (got "
+ << llvm::range_size(targetOps) << ")";
+ }
+ Operation *target = *targetOps.begin();
+
+ auto launchOp = dyn_cast<gpu::LaunchOp>(target);
+ if (!launchOp) {
+ auto diag = emitSilenceableFailure(getLoc())
+ << "Expected a gpu.launch op, but got: " << target->getName();
+ diag.attachNote(target->getLoc()) << "target op";
+ return diag;
+ }
+
+ SmallVector<int32_t> threads;
+ DiagnosedSilenceableFailure status =
+ convertMixedValuesToInt(state, (*this), threads, getMixedThreads());
+ if (!status.succeeded())
+ return status;
+
+ if (threads.size() != 3) {
+ return emitSilenceableFailure(getLoc())
+ << "Expected threads to be a 3D vector";
+ }
+
+ rewriter.setInsertionPoint(launchOp);
+ auto createConstValue = [&](int value) {
+ return arith::ConstantIndexOp::create(rewriter, launchOp.getLoc(), value);
+ };
+
+ // Replace threads in-place.
+ launchOp.getBlockSizeXMutable().assign(createConstValue(threads[0]));
+ launchOp.getBlockSizeYMutable().assign(createConstValue(threads[1]));
+ launchOp.getBlockSizeZMutable().assign(createConstValue(threads[2]));
+
+ return DiagnosedSilenceableFailure::success();
+}
+
+void transform::SetGPULaunchThreadsOp::getEffects(
+ ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ onlyReadsHandle(getTargetMutable(), effects);
+ onlyReadsHandle(getThreadsMutable(), effects);
+ modifiesPayload(effects);
+}
+
namespace {
class XeGPUTransformDialectExtension
: public transform::TransformDialectExtension<
diff --git a/mlir/python/mlir/dialects/transform/xegpu.py b/mlir/python/mlir/dialects/transform/xegpu.py
index 2918bf592880a..dce0e0e97c01f 100644
--- a/mlir/python/mlir/dialects/transform/xegpu.py
+++ b/mlir/python/mlir/dialects/transform/xegpu.py
@@ -64,3 +64,30 @@ def __init__(
loc=loc,
ip=ip,
)
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class SetGPULaunchThreadsOp(SetGPULaunchThreadsOp):
+ """Specialization for SetGPULaunchThreadsOp class."""
+
+ def __init__(
+ self,
+ launch_op: Union[Operation, Value],
+ threads: MixedValues,
+ *,
+ loc=None,
+ ip=None,
+ ):
+ (
+ dynamic_threads,
+ static_threads,
+ _,
+ ) = _dispatch_dynamic_index_list(threads)
+
+ super().__init__(
+ _get_op_result_or_value(launch_op),
+ dynamic_threads,
+ static_threads=static_threads,
+ loc=loc,
+ ip=ip,
+ )
diff --git a/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir b/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
index 303584518f9f4..be6943c5b8db9 100644
--- a/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
@@ -13,3 +13,56 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+func.func @set_gpu_launch_threads_bad_handle(%arg0: memref<4096x4096xf16>) {
+ %c32 = arith.constant 32 : index // expected-note {{target op}}
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["arith.constant"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // expected-error@below {{Expected a gpu.launch op, but got: arith.constant}}
+ transform.xegpu.set_gpu_launch_threads %0 threads = [8, 4, 1] : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @set_gpu_launch_threads_many_handles(%arg0: memref<4096x4096xf16>) {
+ %c32 = arith.constant 32 : index
+ %c64 = arith.constant 64 : index
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["arith.constant"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // expected-error@below {{Requires exactly one targetOp handle (got 2)}}
+ transform.xegpu.set_gpu_launch_threads %0 threads = [8, 4, 1] : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @set_gpu_launch_threads_bad_threads(%arg0: memref<4096x4096xf16>) {
+ %c1 = arith.constant 1 : index
+ %c16 = arith.constant 16 : index
+ gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg9 = %c16, %arg10 = %c16, %arg11 = %c1) threads(%arg6, %arg7, %arg8) in (%arg12 = %c1, %arg13 = %c1, %arg14 = %c1) {
+ gpu.terminator
+ } {SCFToGPU_visited}
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["gpu.launch"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // expected-error@below {{Expected threads to be a 3D vector}}
+ transform.xegpu.set_gpu_launch_threads %0 threads = [8, 4] : !transform.any_op
+ transform.yield
+ }
+}
diff --git a/mlir/test/Dialect/XeGPU/transform-ops.mlir b/mlir/test/Dialect/XeGPU/transform-ops.mlir
index 23e1cd946b4cd..cb8d1e9afcc6e 100644
--- a/mlir/test/Dialect/XeGPU/transform-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/transform-ops.mlir
@@ -56,3 +56,58 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+// CHECK-LABEL: @set_gpu_launch_threads
+func.func @set_gpu_launch_threads(%arg0: memref<4096x4096xf16>) {
+ // CHECK: %[[C1:.+]] = arith.constant 1 : index
+ %c1 = arith.constant 1 : index
+ // CHECK: %[[C16:.+]] = arith.constant 16 : index
+ %c16 = arith.constant 16 : index
+ // CHECK: %[[C8:.+]] = arith.constant 8 : index
+ // CHECK: %[[C4:.+]] = arith.constant 4 : index
+ // CHECK: %[[C1_0:.+]] = arith.constant 1 : index
+ // CHECK: gpu.launch blocks(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %[[C16]], %{{.*}} = %[[C16]], %{{.*}} = %[[C1]])
+ // CHECK-SAME: threads(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %[[C8]], %{{.*}} = %[[C4]], %{{.*}} = %[[C1_0]])
+ gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg9 = %c16, %arg10 = %c16, %arg11 = %c1) threads(%arg6, %arg7, %arg8) in (%arg12 = %c1, %arg13 = %c1, %arg14 = %c1) {
+ gpu.terminator
+ } {SCFToGPU_visited}
+ return
+}
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["gpu.launch"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // CHECK: transform.xegpu.set_gpu_launch_threads %{{.*}}
+ transform.xegpu.set_gpu_launch_threads %0 threads = [8, 4, 1] : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @set_gpu_launch_threads_param
+func.func @set_gpu_launch_threads_param(%arg0: memref<4096x4096xf16>) {
+ // CHECK: %[[C1:.+]] = arith.constant 1 : index
+ %c1 = arith.constant 1 : index
+ // CHECK: %[[C16:.+]] = arith.constant 16 : index
+ %c16 = arith.constant 16 : index
+ // CHECK: %[[C8:.+]] = arith.constant 8 : index
+ // CHECK: %[[C4:.+]] = arith.constant 4 : index
+ // CHECK: %[[C1_0:.+]] = arith.constant 1 : index
+ // CHECK: gpu.launch blocks(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %[[C16]], %{{.*}} = %[[C16]], %{{.*}} = %[[C1]])
+ // CHECK-SAME: threads(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %[[C8]], %{{.*}} = %[[C4]], %{{.*}} = %[[C1_0]])
+ gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg9 = %c16, %arg10 = %c16, %arg11 = %c1) threads(%arg6, %arg7, %arg8) in (%arg12 = %c1, %arg13 = %c1, %arg14 = %c1) {
+ gpu.terminator
+ } {SCFToGPU_visited}
+ return
+}
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["gpu.launch"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // CHECK: transform.xegpu.set_gpu_launch_threads %{{.*}}
+ %th1 = transform.param.constant 4 : i64 -> !transform.param<i64>
+ transform.xegpu.set_gpu_launch_threads %0 threads = [8, %th1, 1] : !transform.any_op, !transform.param<i64>
+ transform.yield
+ }
+}
diff --git a/mlir/test/python/dialects/transform_xegpu_ext.py b/mlir/test/python/dialects/transform_xegpu_ext.py
index 1c8a2bcc6a2fb..ff8506c749f0f 100644
--- a/mlir/test/python/dialects/transform_xegpu_ext.py
+++ b/mlir/test/python/dialects/transform_xegpu_ext.py
@@ -49,3 +49,18 @@ def setDescLayoutInstData():
# CHECK: sg_layout = [6, 4]
# CHECK: sg_data = [32, 16]
# CHECK: inst_data = [8, 16]
+
+
+@run
+def setGPULaunchThreadsOp():
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.Propagate,
+ [],
+ transform.OperationType.get("gpu.launch"),
+ )
+ with InsertionPoint(sequence.body):
+ xegpu.SetGPULaunchThreadsOp(sequence.bodyTarget, threads=[8, 4, 1])
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: setGPULaunchThreadsOp
+ # CHECK: transform.xegpu.set_gpu_launch_threads
+ # CHECK: threads = [8, 4, 1]
rolfmorel left a comment:
Looks fine to me. Only a couple of nits.
As this is necessary for the Xe pipelines/schedules, it seems fine to prototype it in this dialect extension. When it matures and we recognize a more general need it is/could be serving, we should move it.
Jianhui-Li left a comment:
LGTM
Force-pushed from 505c8a9 to ba5aa77.
LLVM Buildbot has detected a new failure on a builder. Full details are available at: https://lab.llvm.org/buildbot/#/builders/129/builds/32883. Here is the relevant piece of the build log for reference.
Adds `transform.xegpu.set_gpu_launch_threads` that overrides the thread sizes of a `gpu.launch` operation. For reference, the rationale behind xegpu transform ops is outlined in this RFC document.
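For the Python bindings, a rough usage sketch mirroring the test added in mlir/test/python/dialects/transform_xegpu_ext.py; the surrounding context and module setup below is an assumption about a standard upstream MLIR Python build (all dialects registered), not part of the patch:

```python
from mlir.ir import Context, InsertionPoint, Location, Module
from mlir.dialects import transform
from mlir.dialects.transform import xegpu

with Context(), Location.unknown():
    module = Module.create()
    with InsertionPoint(module.body):
        # Sequence operating on a handle to a gpu.launch op.
        sequence = transform.SequenceOp(
            transform.FailurePropagationMode.Propagate,
            [],
            transform.OperationType.get("gpu.launch"),
        )
        with InsertionPoint(sequence.body):
            # Override the launch's threads(...) sizes with the static values 8x4x1.
            xegpu.SetGPULaunchThreadsOp(sequence.bodyTarget, threads=[8, 4, 1])
            transform.YieldOp()
    print(module)
```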