Skip to content

Commit 67e32af

Browse files
[mlir][python] Add pythonic interface for GPUFunc
The func dialect provides a more pythonic interface for constructing operations, but the gpu dialect does not; this is the first PR to provide the same conveniences for the gpu dialect, starting with the gpu.func op.
1 parent f49e3d1 commit 67e32af

File tree

2 files changed

+179
-0
lines changed

2 files changed

+179
-0
lines changed

mlir/python/mlir/dialects/gpu/__init__.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,121 @@
33
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
44

55
from .._gpu_ops_gen import *
6+
from .._gpu_ops_gen import _Dialect
67
from .._gpu_enum_gen import *
78
from ..._mlir_libs._mlirDialectsGPU import *
9+
from typing import Callable, Sequence, Union, Optional
10+
11+
try:
12+
from ...ir import (
13+
FunctionType,
14+
TypeAttr,
15+
StringAttr,
16+
UnitAttr,
17+
Block,
18+
InsertionPoint,
19+
ArrayAttr,
20+
Type,
21+
DictAttr,
22+
Attribute,
23+
)
24+
from .._ods_common import (
25+
get_default_loc_context as _get_default_loc_context,
26+
_cext as _ods_cext,
27+
)
28+
except ImportError as e:
29+
raise RuntimeError("Error loading imports from extension module") from e
30+
31+
32+
FUNCTION_TYPE_ATTRIBUTE_NAME = "function_type"
33+
KERNEL_ATTRIBUTE_NAME = "gpu.kernel"
34+
SYM_NAME_ATTRIBUTE_NAME = "sym_name"
35+
ARGUMENT_ATTRIBUTE_NAME = "arg_attrs"
36+
RESULT_ATTRIBUTE_NAME = "res_attrs"
37+
38+
39+
@_ods_cext.register_operation(_Dialect, replace=True)
40+
class GPUFuncOp(GPUFuncOp):
41+
def __init__(
42+
self,
43+
function_type: Union[FunctionType, TypeAttr],
44+
sym_name: Optional[Union[str, StringAttr]] = None,
45+
kernel: Optional[bool] = None,
46+
body_builder: Optional[Callable[[GPUFuncOp], None]] = None,
47+
*args,
48+
loc=None,
49+
ip=None,
50+
**kwargs,
51+
):
52+
function_type = (
53+
TypeAttr.get(function_type)
54+
if not isinstance(function_type, TypeAttr)
55+
else function_type
56+
)
57+
super().__init__(function_type, *args, loc=loc, ip=ip, **kwargs)
58+
if sym_name is not None:
59+
self.attributes[SYM_NAME_ATTRIBUTE_NAME] = StringAttr.get(str(sym_name))
60+
if kernel:
61+
self.attributes[KERNEL_ATTRIBUTE_NAME] = UnitAttr.get()
62+
if body_builder is not None:
63+
with InsertionPoint(self.add_entry_block()):
64+
body_builder(self)
65+
66+
@property
67+
def type(self) -> FunctionType:
68+
return FunctionType(
69+
TypeAttr(self.attributes[FUNCTION_TYPE_ATTRIBUTE_NAME]).value
70+
)
71+
72+
@property
73+
def name(self) -> StringAttr:
74+
return StringAttr(self.attributes[SYM_NAME_ATTRIBUTE_NAME])
75+
76+
@property
77+
def is_kernel(self) -> bool:
78+
return KERNEL_ATTRIBUTE_NAME in self.attributes
79+
80+
def add_entry_block(self) -> Block:
81+
function_type = self.type
82+
return self.body.blocks.append(
83+
*function_type.inputs,
84+
arg_locs=[self.location for _ in function_type.inputs],
85+
)
86+
87+
@property
88+
def entry_block(self) -> Block:
89+
return self.body.blocks[0]
90+
91+
@property
92+
def arguments(self) -> Sequence[Type]:
93+
return self.type.inputs
94+
95+
@property
96+
def arg_attrs(self):
97+
if ARGUMENT_ATTRIBUTE_NAME not in self.attributes:
98+
return ArrayAttr.get([DictAttr.get({}) for _ in self.type.inputs])
99+
return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME])
100+
101+
@arg_attrs.setter
102+
def arg_attrs(self, attribute: Union[ArrayAttr, list[Attribute]]):
103+
if isinstance(attribute, ArrayAttr):
104+
self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute
105+
else:
106+
self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get(
107+
attribute, context=self.context
108+
)
109+
110+
@property
111+
def result_attrs(self) -> Optional[ArrayAttr]:
112+
if RESULT_ATTRIBUTE_NAME not in self.attributes:
113+
return ArrayAttr.get([DictAttr.get({}) for _ in self.type.results])
114+
return self.attributes[RESULT_ATTRIBUTE_NAME]
115+
116+
@result_attrs.setter
117+
def result_attrs(self, attribute: Union[ArrayAttr, list[Attribute]]):
118+
if isinstance(attribute, ArrayAttr):
119+
self.attributes[RESULT_ATTRIBUTE_NAME] = attribute
120+
else:
121+
self.attributes[RESULT_ATTRIBUTE_NAME] = ArrayAttr.get(
122+
attribute, context=self.context
123+
)

mlir/test/python/dialects/gpu/dialect.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import mlir.dialects.gpu as gpu
55
import mlir.dialects.gpu.passes
66
from mlir.passmanager import *
7+
import mlir.ir as ir
78

89

910
def run(f):
@@ -64,3 +65,65 @@ def testObjectAttr():
6465
# CHECK: #gpu.object<#nvvm.target, kernels = <[#gpu.kernel_metadata<"kernel", () -> ()>]>, "BC\C0\DE5\14\00\00\05\00\00\00b\0C0$MY\BEf">
6566
print(o)
6667
assert o.kernels == kernelTable
68+
69+
70+
# CHECK-LABEL: testGPUFuncOp
71+
@run
72+
def testGPUFuncOp():
73+
module = Module.create()
74+
with InsertionPoint(module.body):
75+
gpu_module_name = StringAttr.get("gpu_module")
76+
gpumodule = gpu.GPUModuleOp(gpu_module_name)
77+
block = gpumodule.bodyRegion.blocks.append()
78+
79+
def builder(func: gpu.GPUFuncOp) -> None:
80+
_ = gpu.GlobalIdOp(gpu.Dimension.x)
81+
_ = gpu.ReturnOp([])
82+
83+
with InsertionPoint(block):
84+
name = StringAttr.get("kernel0")
85+
func_type = ir.FunctionType.get(inputs=[], results=[])
86+
type_attr = TypeAttr.get(func_type)
87+
func = gpu.GPUFuncOp(type_attr, name)
88+
func.attributes[gpu.SYM_NAME_ATTRIBUTE_NAME] = name
89+
func.attributes[gpu.KERNEL_ATTRIBUTE_NAME] = UnitAttr.get()
90+
block = func.body.blocks.append()
91+
with InsertionPoint(block):
92+
builder(func)
93+
94+
func = gpu.GPUFuncOp(
95+
func_type,
96+
sym_name="kernel1",
97+
kernel=True,
98+
body_builder=builder,
99+
)
100+
101+
assert func.name.value == "kernel1"
102+
assert func.arg_attrs == ArrayAttr.get([])
103+
assert func.result_attrs == ArrayAttr.get([])
104+
assert func.arguments == []
105+
assert func.entry_block == func.body.blocks[0]
106+
assert func.is_kernel
107+
108+
non_kernel_func = gpu.GPUFuncOp(
109+
func_type,
110+
sym_name="non_kernel_func",
111+
body_builder=builder,
112+
)
113+
assert not non_kernel_func.is_kernel
114+
115+
print(module)
116+
117+
# CHECK: gpu.module @gpu_module
118+
# CHECK: gpu.func @kernel0() kernel {
119+
# CHECK: %[[VAL_0:.*]] = gpu.global_id x
120+
# CHECK: gpu.return
121+
# CHECK: }
122+
# CHECK: gpu.func @kernel1() kernel {
123+
# CHECK: %[[VAL_0:.*]] = gpu.global_id x
124+
# CHECK: gpu.return
125+
# CHECK: }
126+
# CHECK: gpu.func @non_kernel_func() {
127+
# CHECK: %[[VAL_0:.*]] = gpu.global_id x
128+
# CHECK: gpu.return
129+
# CHECK: }

0 commit comments

Comments
 (0)