Conversation

@tkarna
Contributor

@tkarna tkarna commented Nov 6, 2025

Adds transform.xegpu.set_gpu_launch_threads, which overrides the thread dimensions of a gpu.launch operation.

For reference, the rationale behind xegpu transform ops is outlined in this RFC document.
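
For a quick picture of the intended usage, here is a minimal transform script sketch (it mirrors the transform-ops.mlir test added in this PR; the %root and %launch names are illustrative):

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
    // Match the payload gpu.launch and override its thread dimensions to 8x4x1.
    %launch = transform.structured.match ops{["gpu.launch"]} in %root
        : (!transform.any_op) -> !transform.any_op
    transform.xegpu.set_gpu_launch_threads %launch threads = [8, 4, 1] : !transform.any_op
    transform.yield
  }
}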

@llvmbot
Member

llvmbot commented Nov 6, 2025

@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir

Author: Tuomas Kärnä (tkarna)

Changes

Adds transform.xegpu.set_gpu_launch_threads, which overrides the thread dimensions of a gpu.launch operation.

For reference, the rationale behind xegpu transform ops is outlined in this RFC document.


Full diff: https://github.com/llvm/llvm-project/pull/166865.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td (+37)
  • (modified) mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp (+64)
  • (modified) mlir/python/mlir/dialects/transform/xegpu.py (+27)
  • (modified) mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir (+53)
  • (modified) mlir/test/Dialect/XeGPU/transform-ops.mlir (+55)
  • (modified) mlir/test/python/dialects/transform_xegpu_ext.py (+15)
diff --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
index b985d5450be0e..0b138d13d8aee 100644
--- a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
@@ -78,4 +78,41 @@ def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [
   }];
 }
 
+def SetGPULaunchThreadsOp
+    : Op<Transform_Dialect, "xegpu.set_gpu_launch_threads", [
+      DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+      TransformOpInterface
+    ]> {
+
+  let summary = "Set number of threads for a given gpu.launch operation";
+  let description = "Set number of threads for a given `gpu.launch` operation.";
+
+  let arguments = (ins TransformHandleTypeInterface : $target,
+                   Variadic<TransformAnyParamTypeOrAnyHandle> : $threads,
+                   DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_threads
+                   );
+  let results = (outs);
+  let builders = [
+    OpBuilder<(ins "Value":$target, "ArrayRef<OpFoldResult>":$mixedThreads)>,
+  ];
+
+  let assemblyFormat = [{
+    $target
+    `threads` `=` custom<DynamicIndexList>($threads, $static_threads)
+    attr-dict `:` qualified(type(operands))
+  }];
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure apply(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::transform::TransformResults &transformResults,
+        ::mlir::transform::TransformState &state);
+
+    ::llvm::SmallVector<::mlir::OpFoldResult> getMixedThreads() {
+      Builder b(getContext());
+      return getMixedValues(getStaticThreads(), getThreads(), b);
+    }
+  }];
+}
+
 #endif // XEGPU_TRANSFORM_OPS
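
A note on the assembly format above: the custom DynamicIndexList directive lets threads mix static integers with dynamic values such as transform params. A minimal sketch, taken from the tests later in this diff (%launch is assumed to be a handle to a gpu.launch op):

%th = transform.param.constant 4 : i64 -> !transform.param<i64>
transform.xegpu.set_gpu_launch_threads %launch threads = [8, %th, 1]
    : !transform.any_op, !transform.param<i64>

Static entries land in the static_threads attribute, while dynamic ones populate the variadic threads operand.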
diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
index 8943ba09d9c34..73c3776a75276 100644
--- a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
@@ -193,6 +194,69 @@ void transform::SetDescLayoutOp::getEffects(
   modifiesPayload(effects);
 }
 
+void transform::SetGPULaunchThreadsOp::build(
+    OpBuilder &builder, OperationState &ostate, Value target,
+    ArrayRef<OpFoldResult> mixedThreads) {
+  SmallVector<int64_t> staticThreads;
+  SmallVector<Value> dynamicThreads;
+  dispatchIndexOpFoldResults(mixedThreads, dynamicThreads, staticThreads);
+  build(builder, ostate, target.getType(),
+        /*target=*/target,
+        /*threads=*/dynamicThreads,
+        /*static_threads=*/staticThreads);
+}
+
+DiagnosedSilenceableFailure
+transform::SetGPULaunchThreadsOp::apply(transform::TransformRewriter &rewriter,
+                                        transform::TransformResults &results,
+                                        transform::TransformState &state) {
+
+  auto targetOps = state.getPayloadOps(getTarget());
+  if (!llvm::hasSingleElement(targetOps)) {
+    return emitDefiniteFailure() << "Requires exactly one targetOp handle (got "
+                                 << llvm::range_size(targetOps) << ")";
+  }
+  Operation *target = *targetOps.begin();
+
+  auto launchOp = dyn_cast<gpu::LaunchOp>(target);
+  if (!launchOp) {
+    auto diag = emitSilenceableFailure(getLoc())
+                << "Expected a gpu.launch op, but got: " << target->getName();
+    diag.attachNote(target->getLoc()) << "target op";
+    return diag;
+  }
+
+  SmallVector<int32_t> threads;
+  DiagnosedSilenceableFailure status =
+      convertMixedValuesToInt(state, (*this), threads, getMixedThreads());
+  if (!status.succeeded())
+    return status;
+
+  if (threads.size() != 3) {
+    return emitSilenceableFailure(getLoc())
+           << "Expected threads to be a 3D vector";
+  }
+
+  rewriter.setInsertionPoint(launchOp);
+  auto createConstValue = [&](int value) {
+    return arith::ConstantIndexOp::create(rewriter, launchOp.getLoc(), value);
+  };
+
+  // Replace threads in-place.
+  launchOp.getBlockSizeXMutable().assign(createConstValue(threads[0]));
+  launchOp.getBlockSizeYMutable().assign(createConstValue(threads[1]));
+  launchOp.getBlockSizeZMutable().assign(createConstValue(threads[2]));
+
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform::SetGPULaunchThreadsOp::getEffects(
+    ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  onlyReadsHandle(getTargetMutable(), effects);
+  onlyReadsHandle(getThreadsMutable(), effects);
+  modifiesPayload(effects);
+}
+
 namespace {
 class XeGPUTransformDialectExtension
     : public transform::TransformDialectExtension<
diff --git a/mlir/python/mlir/dialects/transform/xegpu.py b/mlir/python/mlir/dialects/transform/xegpu.py
index 2918bf592880a..dce0e0e97c01f 100644
--- a/mlir/python/mlir/dialects/transform/xegpu.py
+++ b/mlir/python/mlir/dialects/transform/xegpu.py
@@ -64,3 +64,30 @@ def __init__(
             loc=loc,
             ip=ip,
         )
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class SetGPULaunchThreadsOp(SetGPULaunchThreadsOp):
+    """Specialization for SetGPULaunchThreadsOp class."""
+
+    def __init__(
+        self,
+        launch_op: Union[Operation, Value],
+        threads: MixedValues,
+        *,
+        loc=None,
+        ip=None,
+    ):
+        (
+            dynamic_threads,
+            static_threads,
+            _,
+        ) = _dispatch_dynamic_index_list(threads)
+
+        super().__init__(
+            _get_op_result_or_value(launch_op),
+            dynamic_threads,
+            static_threads=static_threads,
+            loc=loc,
+            ip=ip,
+        )
diff --git a/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir b/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
index 303584518f9f4..be6943c5b8db9 100644
--- a/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
@@ -13,3 +13,56 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+func.func @set_gpu_launch_threads_bad_handle(%arg0: memref<4096x4096xf16>) {
+  %c32 = arith.constant 32 : index // expected-note {{target op}}
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["arith.constant"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // expected-error@below {{Expected a gpu.launch op, but got: arith.constant}}
+    transform.xegpu.set_gpu_launch_threads %0 threads = [8, 4, 1] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @set_gpu_launch_threads_many_handles(%arg0: memref<4096x4096xf16>) {
+  %c32 = arith.constant 32 : index
+  %c64 = arith.constant 64 : index
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["arith.constant"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // expected-error@below {{Requires exactly one targetOp handle (got 2)}}
+    transform.xegpu.set_gpu_launch_threads %0 threads = [8, 4, 1] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @set_gpu_launch_threads_bad_threads(%arg0: memref<4096x4096xf16>) {
+  %c1 = arith.constant 1 : index
+  %c16 = arith.constant 16 : index
+  gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg9 = %c16, %arg10 = %c16, %arg11 = %c1) threads(%arg6, %arg7, %arg8) in (%arg12 = %c1, %arg13 = %c1, %arg14 = %c1) {
+    gpu.terminator
+  } {SCFToGPU_visited}
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["gpu.launch"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // expected-error@below {{Expected threads to be a 3D vector}}
+    transform.xegpu.set_gpu_launch_threads %0 threads = [8, 4] : !transform.any_op
+    transform.yield
+  }
+}
diff --git a/mlir/test/Dialect/XeGPU/transform-ops.mlir b/mlir/test/Dialect/XeGPU/transform-ops.mlir
index 23e1cd946b4cd..cb8d1e9afcc6e 100644
--- a/mlir/test/Dialect/XeGPU/transform-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/transform-ops.mlir
@@ -56,3 +56,58 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+// CHECK-LABEL: @set_gpu_launch_threads
+func.func @set_gpu_launch_threads(%arg0: memref<4096x4096xf16>) {
+  // CHECK: %[[C1:.+]] = arith.constant 1 : index
+  %c1 = arith.constant 1 : index
+  // CHECK: %[[C16:.+]] = arith.constant 16 : index
+  %c16 = arith.constant 16 : index
+  // CHECK: %[[C8:.+]] = arith.constant 8 : index
+  // CHECK: %[[C4:.+]] = arith.constant 4 : index
+  // CHECK: %[[C1_0:.+]] = arith.constant 1 : index
+  // CHECK: gpu.launch blocks(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %[[C16]], %{{.*}} = %[[C16]], %{{.*}} = %[[C1]])
+  // CHECK-SAME: threads(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %[[C8]], %{{.*}} = %[[C4]], %{{.*}} = %[[C1_0]])
+  gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg9 = %c16, %arg10 = %c16, %arg11 = %c1) threads(%arg6, %arg7, %arg8) in (%arg12 = %c1, %arg13 = %c1, %arg14 = %c1) {
+    gpu.terminator
+  } {SCFToGPU_visited}
+  return
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["gpu.launch"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // CHECK: transform.xegpu.set_gpu_launch_threads %{{.*}}
+    transform.xegpu.set_gpu_launch_threads %0 threads = [8, 4, 1] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @set_gpu_launch_threads_param
+func.func @set_gpu_launch_threads_param(%arg0: memref<4096x4096xf16>) {
+  // CHECK: %[[C1:.+]] = arith.constant 1 : index
+  %c1 = arith.constant 1 : index
+  // CHECK: %[[C16:.+]] = arith.constant 16 : index
+  %c16 = arith.constant 16 : index
+  // CHECK: %[[C8:.+]] = arith.constant 8 : index
+  // CHECK: %[[C4:.+]] = arith.constant 4 : index
+  // CHECK: %[[C1_0:.+]] = arith.constant 1 : index
+  // CHECK: gpu.launch blocks(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %[[C16]], %{{.*}} = %[[C16]], %{{.*}} = %[[C1]])
+  // CHECK-SAME: threads(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %[[C8]], %{{.*}} = %[[C4]], %{{.*}} = %[[C1_0]])
+  gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg9 = %c16, %arg10 = %c16, %arg11 = %c1) threads(%arg6, %arg7, %arg8) in (%arg12 = %c1, %arg13 = %c1, %arg14 = %c1) {
+    gpu.terminator
+  } {SCFToGPU_visited}
+  return
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["gpu.launch"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // CHECK: transform.xegpu.set_gpu_launch_threads %{{.*}}
+    %th1 = transform.param.constant 4 : i64 -> !transform.param<i64>
+    transform.xegpu.set_gpu_launch_threads %0 threads = [8, %th1, 1] : !transform.any_op, !transform.param<i64>
+    transform.yield
+  }
+}
diff --git a/mlir/test/python/dialects/transform_xegpu_ext.py b/mlir/test/python/dialects/transform_xegpu_ext.py
index 1c8a2bcc6a2fb..ff8506c749f0f 100644
--- a/mlir/test/python/dialects/transform_xegpu_ext.py
+++ b/mlir/test/python/dialects/transform_xegpu_ext.py
@@ -49,3 +49,18 @@ def setDescLayoutInstData():
     # CHECK: sg_layout = [6, 4]
     # CHECK: sg_data = [32, 16]
     # CHECK: inst_data = [8, 16]
+
+
+@run
+def setGPULaunchThreadsOp():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.Propagate,
+        [],
+        transform.OperationType.get("gpu.launch"),
+    )
+    with InsertionPoint(sequence.body):
+        xegpu.SetGPULaunchThreadsOp(sequence.bodyTarget, threads=[8, 4, 1])
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: setGPULaunchThreadsOp
+    # CHECK: transform.xegpu.set_gpu_launch_threads
+    # CHECK: threads = [8, 4, 1]
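
To make the payload rewrite concrete, here is a hedged before/after sketch consistent with the CHECK lines in transform-ops.mlir above (value names are illustrative): apply creates fresh arith.constant index ops and reassigns the block-size operands in place.

// Before: all thread dimensions are %c1.
gpu.launch blocks(%bx, %by, %bz) in (%gx = %c16, %gy = %c16, %gz = %c1)
           threads(%tx, %ty, %tz) in (%sx = %c1, %sy = %c1, %sz = %c1) {
  gpu.terminator
}

// After threads = [8, 4, 1] is applied:
%c8 = arith.constant 8 : index
%c4 = arith.constant 4 : index
%c1_0 = arith.constant 1 : index
gpu.launch blocks(%bx, %by, %bz) in (%gx = %c16, %gy = %c16, %gz = %c1)
           threads(%tx, %ty, %tz) in (%sx = %c8, %sy = %c4, %sz = %c1_0) {
  gpu.terminator
}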

Contributor

@rolfmorel rolfmorel left a comment

Looks fine to me. Only a couple of nits.

As this is necessary for the Xe pipelines/schedules, it seems fine to prototype it in this dialect extension. When it matures and we recognize a more general need it could be serving, we should move it.

Contributor

@Jianhui-Li Jianhui-Li left a comment

LGTM

@tkarna tkarna force-pushed the xegpu-tr-ops-set-gpu-threads branch from 505c8a9 to ba5aa77 on November 11, 2025 06:59
@rolfmorel rolfmorel merged commit 300750d into llvm:main Nov 11, 2025
10 checks passed
@tkarna tkarna deleted the xegpu-tr-ops-set-gpu-threads branch November 11, 2025 11:58
WillFroom added a commit to WillFroom/llvm-project that referenced this pull request Nov 11, 2025
@llvm-ci
Collaborator

llvm-ci commented Nov 11, 2025

LLVM Buildbot has detected a new failure on builder ppc64le-mlir-rhel-clang running on ppc64le-mlir-rhel-test while building mlir at step 6 "test-build-check-mlir-build-only-check-mlir".

Full details are available at: https://lab.llvm.org/buildbot/#/builders/129/builds/32883

Here is the relevant piece of the build log for reference:
Step 6 (test-build-check-mlir-build-only-check-mlir) failure: 1200 seconds without output running [b'ninja', b'check-mlir'], attempting to kill
5.550 [0/1/0] Running the MLIR regression tests
command timed out: 1200 seconds without output running [b'ninja', b'check-mlir'], attempting to kill
process killed by signal 9
program finished with exit code -1
elapsedTime=1206.059447
