2 changes: 1 addition & 1 deletion mlir/docs/Dialects/GPU.md
@@ -121,7 +121,7 @@ func.func @main() {
gpu.launch
blocks(%0, %1, %2) in (%3 = %c1, %4 = %c1, %5 = %c1)
threads(%6, %7, %8) in (%9 = %c2, %10 = %c1, %11 = %c1) {
gpu.printf "Hello from %d\n" %6 : index
gpu.printf "Hello from %d\n", %6 : index
gpu.terminator
}
return
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -584,7 +584,7 @@ def GPU_DynamicSharedMemoryOp : GPU_Op<"dynamic_shared_memory", [Pure]>
This operation provides a memref pointer to the start of dynamic shared
memory, often referred to as workgroup memory. It's important to note that
this dynamic shared memory needs to be allocated at kernel launch. One can
-conveniently utilize `the dynamic_shared_memory_size` parameter of
+conveniently utilize the `dynamic_shared_memory_size` parameter of
`gpu.launch` for this purpose.

Examples:
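Editorial aside, not part of the diff: the `dynamic_shared_memory_size` knob described above is also exposed by the Python `launch_func` builder added later in this PR. A minimal sketch, assuming an enclosing MLIR Context/InsertionPoint and an already-declared @gpu_module::@kernel; the symbol names and sizes are made up:

from mlir.dialects import arith, gpu
from mlir.extras import types as T

shmem = arith.constant(T.i32(), 4096)        # bytes of workgroup memory to reserve at launch
token = gpu.wait()                           # open an async chain
token = gpu.launch_func(
    kernel=["gpu_module", "kernel"],         # hypothetical @gpu_module::@kernel symbols
    grid_size=(1, 1, 1),
    block_size=(128, 1, 1),
    kernel_operands=[],
    async_dependencies=[token],
    dynamic_shared_memory_size=shmem,        # becomes the op's dynamic shared memory operand
)
gpu.wait(async_dependencies=[token])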
184 changes: 183 additions & 1 deletion mlir/python/mlir/dialects/gpu/__init__.py
@@ -6,7 +6,7 @@
from .._gpu_ops_gen import _Dialect
from .._gpu_enum_gen import *
from ..._mlir_libs._mlirDialectsGPU import *
-from typing import Callable, Sequence, Union, Optional, List
+from typing import Any, Callable, Sequence, Tuple, Union, Optional, List

try:
from ...ir import (
@@ -21,15 +21,24 @@
DictAttr,
Attribute,
DenseI32ArrayAttr,
Value,
)
from ...extras.meta import region_op
from ...extras import types as T
from ..arith import constant, ConstantOp
from .._ods_common import (
get_default_loc_context as _get_default_loc_context,
_cext as _ods_cext,
get_op_result_or_op_results,
)
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e


def gpu_async_token():
return Type.parse("!gpu.async.token")


@_ods_cext.register_operation(_Dialect, replace=True)
class GPUFuncOp(GPUFuncOp):
__doc__ = GPUFuncOp.__doc__
@@ -151,3 +160,176 @@ def entry_block(self) -> Block:
@property
def arguments(self) -> Sequence[Type]:
return self.function_type.value.inputs


def _convert_literal_to_constant(value: Union[int, ConstantOp, Value]) -> Value:
if isinstance(value, int):
return constant(T.index(), value)
elif isinstance(value, (ConstantOp, Value)):
return value
else:
raise ValueError(f"Invalid value: {value}")


@_ods_cext.register_operation(_Dialect, replace=True)
class LaunchFuncOp(LaunchFuncOp):
__doc__ = LaunchFuncOp.__doc__

def __init__(
self,
kernel: List[str],
grid_size: Tuple[Any, Any, Any],
block_size: Tuple[Any, Any, Any],
kernel_operands: Optional[List[Value]] = None,
async_dependencies: Optional[List[Value]] = None,
dynamic_shared_memory_size: Optional[Value] = None,
async_object=None,
*,
loc=None,
ip=None,
):
if async_dependencies is None:
async_dependencies = []
async_token = None
if len(async_dependencies):
async_token = gpu_async_token()

grid_size_x, grid_size_y, grid_size_z = map(
_convert_literal_to_constant, grid_size
)
block_size_x, block_size_y, block_size_z = map(
_convert_literal_to_constant, block_size
)

super().__init__(
async_token,
async_dependencies,
kernel,
grid_size_x,
grid_size_y,
grid_size_z,
block_size_x,
block_size_y,
block_size_z,
kernel_operands,
dynamicSharedMemorySize=dynamic_shared_memory_size,
asyncObject=async_object,
loc=loc,
ip=ip,
)


def launch_func(
kernel: List[str],
grid_size: Tuple[Any, Any, Any],
block_size: Tuple[Any, Any, Any],
kernel_operands: Optional[List[Value]] = None,
async_dependencies: Optional[List[Value]] = None,
dynamic_shared_memory_size: Optional[Value] = None,
async_object=None,
*,
loc=None,
ip=None,
) -> Union[Value, List[Value], LaunchFuncOp]:
op = LaunchFuncOp(
kernel=kernel,
grid_size=grid_size,
block_size=block_size,
kernel_operands=kernel_operands,
async_dependencies=async_dependencies,
dynamic_shared_memory_size=dynamic_shared_memory_size,
async_object=async_object,
loc=loc,
ip=ip,
)
results = op.results
if len(results) == 1:
return results[0]
elif len(results) > 1:
return results
else:
return op


def wait(
async_dependencies: Optional[List[Value]] = None, *, loc=None, ip=None
) -> Union[Value, List[Value], WaitOp]:
if async_dependencies is None:
async_dependencies = []
return get_op_result_or_op_results(
WaitOp(gpu_async_token(), async_dependencies, loc=loc, ip=ip)
)


@_ods_cext.register_operation(_Dialect, replace=True)
class LaunchOp(LaunchOp):
__doc__ = LaunchOp.__doc__

def __init__(
self,
grid_size: Tuple[Any, Any, Any],
block_size: Tuple[Any, Any, Any],
async_dependencies=None,
dynamic_shared_memory_size: Optional[Value] = None,
*,
loc=None,
ip=None,
):
if async_dependencies is None:
async_dependencies = []
async_token = None
if len(async_dependencies):
async_token = gpu_async_token()
grid_size_x, grid_size_y, grid_size_z = map(
_convert_literal_to_constant, grid_size
)
block_size_x, block_size_y, block_size_z = map(
_convert_literal_to_constant, block_size
)

super().__init__(
async_token,
async_dependencies,
grid_size_x,
grid_size_y,
grid_size_z,
block_size_x,
block_size_y,
block_size_z,
dynamicSharedMemorySize=dynamic_shared_memory_size,
loc=loc,
ip=ip,
)
self.regions[0].blocks.append(*[T.index() for _ in range(12)])


def launch_(
grid_size: Tuple[Any, Any, Any],
block_size: Tuple[Any, Any, Any],
async_dependencies=None,
dynamic_shared_memory_size: Optional[Value] = None,
*,
loc=None,
ip=None,
):
grid_size = tuple(map(_convert_literal_to_constant, grid_size))
block_size = tuple(map(_convert_literal_to_constant, block_size))
launch_op = LaunchOp(
grid_size,
block_size,
async_dependencies,
dynamic_shared_memory_size,
loc=loc,
ip=ip,
)
return launch_op


launch = region_op(launch_, terminator=lambda *_args: terminator())


_printf = printf


def printf(format, *args, loc=None, ip=None):
return _printf(format=format, args=args, loc=loc, ip=ip)
Comment on lines +334 to +335
Contributor:
ok lol this is a good use of *args
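Editorial aside, not part of the diff: a sketch of how the `launch` region builder and the variadic `printf` wrapper above compose, mirroring the test added below. It assumes an enclosing Context/InsertionPoint; the format string and sizes are illustrative.

from mlir.dialects import gpu

# launch(grid, block) constructs a gpu.launch op; calling the result with a
# function materializes the body region, whose block arguments are the twelve
# block/thread ids and dimensions.
body = gpu.launch((1, 1, 1), (1, 1, 1))
body(lambda *ids: gpu.printf("Hello from %d\n", ids[0]))   # *args forwarded to gpu.printf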

96 changes: 95 additions & 1 deletion mlir/test/python/dialects/gpu/dialect.py
@@ -2,7 +2,8 @@

from mlir.ir import *
import mlir.ir as ir
-import mlir.dialects.gpu as gpu
+from mlir.dialects import gpu, func, arith, math
+from mlir.extras import types as T
import mlir.dialects.gpu.passes
from mlir.passmanager import *

Expand Down Expand Up @@ -157,3 +158,96 @@ def builder(func: gpu.GPUFuncOp) -> None:
# CHECK: %[[VAL_0:.*]] = gpu.global_id x
# CHECK: gpu.return
# CHECK: }


# CHECK-LABEL: testGPULaunchFuncOp
@run
def testGPULaunchFuncOp():
module = Module.create()

module.operation.attributes["gpu.container_module"] = UnitAttr.get()
with InsertionPoint(module.body):
gpu_module = gpu.GPUModuleOp("gpu_module")
block = gpu_module.bodyRegion.blocks.append()

with InsertionPoint(block):
gpu_func = gpu.GPUFuncOp(
FunctionType.get([], []),
"kernel",
body_builder=lambda func: gpu.return_([]),
kernel=True,
)

with InsertionPoint(module.body):
host = func.FuncOp(type=FunctionType.get([], []), name="host")

with InsertionPoint(host.add_entry_block()):
c1 = arith.constant(T.index(), 1)
grid_sizes = (1, 1, 1)
block_sizes = (1, 1, 1)
token = gpu.wait()
token = gpu.launch_func(
Contributor (@makslevental, Oct 16, 2025):
you can just grab this https://github.com/makslevental/mlir-python-extras/blob/main/mlir/extras/dialects/ext/gpu.py#L339-L379 (which supports exactly what you're saying - 3-tuples), put it in mlir/dialects/gpu.py (along with the standard register_operation thing) and then just below wrap it in launch_func (thereby shadowing auto-generated launch_func).
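(Editorial sketch of the shadowing idiom described above, mirroring what the diff ends up doing; the `_GenLaunchFuncOp` alias is made up for clarity:)

# mlir/dialects/gpu/__init__.py -- shadow the generated builder with a nicer one
from .._gpu_ops_gen import _Dialect, LaunchFuncOp as _GenLaunchFuncOp
from .._ods_common import _cext as _ods_cext

@_ods_cext.register_operation(_Dialect, replace=True)   # re-register, replacing the generated class
class LaunchFuncOp(_GenLaunchFuncOp):
    ...                                                  # custom __init__ accepting 3-tuples

def launch_func(*args, **kwargs):                        # shadows the auto-generated helper
    return LaunchFuncOp(*args, **kwargs)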

Contributor Author:
That all looks great to me, but are we worried about forcing folks to update their code if it depends on the Python bindings?

Contributor:
what do you mean? oh you're saying since this meaningfully changes the signature of both these existing APIs (LaunchFuncOp and launch_func) relative to the auto-generated ones?

Contributor Author:
Right - anyone using either api will need to update their code. That's annoying! But maybe worth it. Just making sure 😄

Contributor (@makslevental, Oct 17, 2025):
The python APIs aren't stable (ie we make no stability guarantees). So basically this same "breakage" occurs whenever we add one of these nicer builders. Also there's a simple "migration path": people can just import the generated original APIs directly from _gpu_ops_gen if they really want to keep using the old (auto-generated) APIs.
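(Editorial sketch of that migration path, assuming the generated module keeps its current location and helper names:)

# keep using the auto-generated builders by importing them directly
from mlir.dialects._gpu_ops_gen import LaunchFuncOp, launch_func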

Contributor Author:
Since users always have an easy way to get their old bindings back, I feel better about adding the Python-extras builders. I'll pull in the builder you linked, and I'll be ready to merge if I get an OK from Mehdi or Guray. Thanks!

Member:
I agree with @makslevental about the stability guarantees. MLIR is more progressive compared to LLVM — there’s no API stability guarantee as long as you can migrate to something better.

Contributor Author:
Good to know 😃

async_dependencies=[token],
kernel=[gpu_module.sym_name.value, gpu_func.name.value],
grid_size=grid_sizes,
block_size=block_sizes,
kernel_operands=[],
)
gpu.wait(async_dependencies=[token])
func.ReturnOp([])

print(module)

# CHECK-LABEL: gpu.module @gpu_module {
# CHECK: gpu.func @kernel() kernel {
# CHECK: gpu.return
# CHECK: }
# CHECK: }

# CHECK-LABEL: func.func @host() {
# CHECK: %[[CONSTANT_0:.*]] = arith.constant 1 : index
# CHECK: %[[WAIT_0:.*]] = gpu.wait async
# CHECK: %[[CONSTANT_1:.*]] = arith.constant 1 : index
# CHECK: %[[CONSTANT_2:.*]] = arith.constant 1 : index
# CHECK: %[[CONSTANT_3:.*]] = arith.constant 1 : index
# CHECK: %[[CONSTANT_4:.*]] = arith.constant 1 : index
# CHECK: %[[CONSTANT_5:.*]] = arith.constant 1 : index
# CHECK: %[[CONSTANT_6:.*]] = arith.constant 1 : index
# CHECK: %[[LAUNCH_FUNC_0:.*]] = gpu.launch_func async {{\[}}%[[WAIT_0]]] @gpu_module::@kernel blocks in (%[[CONSTANT_1]], %[[CONSTANT_2]], %[[CONSTANT_3]]) threads in (%[[CONSTANT_4]], %[[CONSTANT_5]], %[[CONSTANT_6]])
# CHECK: %[[WAIT_1:.*]] = gpu.wait async {{\[}}%[[LAUNCH_FUNC_0]]]
# CHECK: return
# CHECK: }


# CHECK-LABEL: testGPULaunchOp
@run
def testGPULaunchOp():
module = Module.create()

with InsertionPoint(module.body):
host = func.FuncOp(type=FunctionType.get([T.f32()], []), name="gpu_printf")

entry_block = host.add_entry_block()
with InsertionPoint(entry_block):
c1 = arith.constant(T.index(), 1)
grid_sizes = (c1, c1, c1)
block_sizes = (c1, c1, c1)

launch = gpu.launch(grid_sizes, block_sizes)

op = launch(lambda *args: gpu.printf("%f", args[0]))

with InsertionPoint(entry_block):
func.ReturnOp([])

print(module)

# CHECK-LABEL: func.func @gpu_printf(
# CHECK-SAME: %[[ARG0:.*]]: f32) {
# CHECK: %[[CONSTANT_0:.*]] = arith.constant 1 : index
# CHECK: gpu.launch blocks(%[[VAL_0:.*]], %[[VAL_1:.*]], %[[VAL_2:.*]]) in (%[[VAL_3:.*]] = %[[CONSTANT_0]], %[[VAL_4:.*]] = %[[CONSTANT_0]], %[[VAL_5:.*]] = %[[CONSTANT_0]]) threads(%[[VAL_6:.*]], %[[VAL_7:.*]], %[[VAL_8:.*]]) in (%[[VAL_9:.*]] = %[[CONSTANT_0]], %[[VAL_10:.*]] = %[[CONSTANT_0]], %[[VAL_11:.*]] = %[[CONSTANT_0]]) {
# CHECK: gpu.printf "%[[VAL_12:.*]]", %[[VAL_0]] : index
# CHECK: gpu.terminator
# CHECK: }
# CHECK: return
# CHECK: }