66from .._gpu_ops_gen import _Dialect
77from .._gpu_enum_gen import *
88from ..._mlir_libs ._mlirDialectsGPU import *
9- from typing import Callable , Sequence , Union , Optional
9+ from typing import Callable , Sequence , Union , Optional , List
1010
1111try :
1212 from ...ir import (
2020 Type ,
2121 DictAttr ,
2222 Attribute ,
23+ DenseI32ArrayAttr ,
2324 )
2425 from .._ods_common import (
2526 get_default_loc_context as _get_default_loc_context ,
2930 raise RuntimeError ("Error loading imports from extension module" ) from e
3031
3132
# Discardable attribute names attached to `gpu.func` operations.
KERNEL_ATTRIBUTE_NAME = "gpu.kernel"
KNOWN_BLOCK_SIZE_ATTRIBUTE_NAME = "gpu.known_block_size"
KNOWN_GRID_SIZE_ATTRIBUTE_NAME = "gpu.known_grid_size"

# Names of the function-like attributes carried by the op.
FUNCTION_TYPE_ATTRIBUTE_NAME = "function_type"
SYM_NAME_ATTRIBUTE_NAME = "sym_name"
ARGUMENT_ATTRIBUTE_NAME = "arg_attrs"
RESULT_ATTRIBUTE_NAME = "res_attrs"
3741
38-
3942@_ods_cext .register_operation (_Dialect , replace = True )
4043class GPUFuncOp (GPUFuncOp ):
44+ __doc__ = GPUFuncOp .__doc__
45+
4146 def __init__ (
4247 self ,
4348 function_type : Union [FunctionType , TypeAttr ],
4449 sym_name : Optional [Union [str , StringAttr ]] = None ,
4550 kernel : Optional [bool ] = None ,
4651 body_builder : Optional [Callable [[GPUFuncOp ], None ]] = None ,
52+ known_block_size : Optional [Union [List [int ], DenseI32ArrayAttr ]] = None ,
53+ known_grid_size : Optional [Union [List [int ], DenseI32ArrayAttr ]] = None ,
4754 * args ,
4855 loc = None ,
4956 ip = None ,
5057 ** kwargs ,
5158 ):
59+ """
60+ Create a GPUFuncOp with the provided `function_type`, `sym_name`, `kernel`, `body_builder`, `known_block_size`, and `known_grid_size`.
61+ - `function_type` is a FunctionType or a TypeAttr.
62+ - `sym_name` is a string or a StringAttr representing the function name.
63+ - `kernel` is a boolean representing whether the function is a kernel.
64+ - `body_builder` is an optional callback. When provided, a new entry block
65+ is created and the callback is invoked with the new op as argument within
66+ an InsertionPoint context already set for the block. The callback is
67+ expected to insert a terminator in the block.
68+ - `known_block_size` is an optional list of integers or a DenseI32ArrayAttr representing the known block size.
69+ - `known_grid_size` is an optional list of integers or a DenseI32ArrayAttr representing the known grid size.
70+ """
5271 function_type = (
5372 TypeAttr .get (function_type )
5473 if not isinstance (function_type , TypeAttr )
@@ -62,6 +81,20 @@ def __init__(
6281 if body_builder is not None :
6382 with InsertionPoint (self .add_entry_block ()):
6483 body_builder (self )
84+ if known_block_size is not None :
85+ if isinstance (known_block_size , list ):
86+ self .attributes [KNOWN_BLOCK_SIZE_ATTRIBUTE_NAME ] = (
87+ DenseI32ArrayAttr .get (known_block_size )
88+ )
89+ else :
90+ self .attributes [KNOWN_BLOCK_SIZE_ATTRIBUTE_NAME ] = known_block_size
91+ if known_grid_size is not None :
92+ if isinstance (known_grid_size , list ):
93+ self .attributes [KNOWN_GRID_SIZE_ATTRIBUTE_NAME ] = (
94+ DenseI32ArrayAttr .get (known_grid_size )
95+ )
96+ else :
97+ self .attributes [KNOWN_GRID_SIZE_ATTRIBUTE_NAME ] = known_grid_size
6598
6699 @property
67100 def type (self ) -> FunctionType :
@@ -77,6 +110,18 @@ def name(self) -> StringAttr:
def is_kernel(self) -> bool:
    """Return True when this op carries the `gpu.kernel` attribute."""
    kernel_key = KERNEL_ATTRIBUTE_NAME
    return kernel_key in self.attributes
79112
@property
def known_block_size(self) -> Optional[DenseI32ArrayAttr]:
    """Return the `gpu.known_block_size` attribute of this op, if present.

    Returns:
        The attribute stored under `KNOWN_BLOCK_SIZE_ATTRIBUTE_NAME`, or
        `None` when the op has no such attribute.
    """
    key = KNOWN_BLOCK_SIZE_ATTRIBUTE_NAME
    return self.attributes[key] if key in self.attributes else None
118+
@property
def known_grid_size(self) -> Optional[DenseI32ArrayAttr]:
    """Return the `gpu.known_grid_size` attribute of this op, if present.

    Returns:
        The attribute stored under `KNOWN_GRID_SIZE_ATTRIBUTE_NAME`, or
        `None` when the op has no such attribute.
    """
    key = KNOWN_GRID_SIZE_ATTRIBUTE_NAME
    return self.attributes[key] if key in self.attributes else None
124+
80125 def add_entry_block (self ) -> Block :
81126 function_type = self .type
82127 return self .body .blocks .append (
0 commit comments