@@ -48,32 +48,42 @@ def __init__(
4848 function_type : Union [FunctionType , TypeAttr ],
4949 sym_name : Optional [Union [str , StringAttr ]] = None ,
5050 kernel : Optional [bool ] = None ,
51- body_builder : Optional [Callable [[GPUFuncOp ], None ]] = None ,
51+ workgroup_attrib_attrs : Optional [Sequence [dict ]] = None ,
52+ private_attrib_attrs : Optional [Sequence [dict ]] = None ,
5253 known_block_size : Optional [Union [Sequence [int ], DenseI32ArrayAttr ]] = None ,
5354 known_grid_size : Optional [Union [Sequence [int ], DenseI32ArrayAttr ]] = None ,
54- * args ,
5555 loc = None ,
5656 ip = None ,
57- ** kwargs ,
57+ body_builder : Optional [ Callable [[ GPUFuncOp ], None ]] = None ,
5858 ):
5959 """
60- Create a GPUFuncOp with the provided `function_type`, `sym_name`, `kernel`, `body_builder`, `known_block_size`, and `known_grid_size`.
60+ Create a GPUFuncOp with the provided `function_type`, `sym_name`,
61+ `kernel`, `workgroup_attrib_attrs`, `private_attrib_attrs`, `known_block_size`,
62+ `known_grid_size`, and `body_builder`.
6163 - `function_type` is a FunctionType or a TypeAttr.
6264 - `sym_name` is a string or a StringAttr representing the function name.
6365 - `kernel` is a boolean representing whether the function is a kernel.
66+ - `workgroup_attrib_attrs` is an optional list of dictionaries.
67+ - `private_attrib_attrs` is an optional list of dictionaries.
68+ - `known_block_size` is an optional list of integers or a DenseI32ArrayAttr representing the known block size.
69+ - `known_grid_size` is an optional list of integers or a DenseI32ArrayAttr representing the known grid size.
6470 - `body_builder` is an optional callback. When provided, a new entry block
6571 is created and the callback is invoked with the new op as argument within
6672 an InsertionPoint context already set for the block. The callback is
6773 expected to insert a terminator in the block.
68- - `known_block_size` is an optional list of integers or a DenseI32ArrayAttr representing the known block size.
69- - `known_grid_size` is an optional list of integers or a DenseI32ArrayAttr representing the known grid size.
7074 """
7175 function_type = (
7276 TypeAttr .get (function_type )
7377 if not isinstance (function_type , TypeAttr )
7478 else function_type
7579 )
76- super ().__init__ (function_type , * args , loc = loc , ip = ip , ** kwargs )
80+ super ().__init__ (
81+ function_type ,
82+ workgroup_attrib_attrs = workgroup_attrib_attrs ,
83+ private_attrib_attrs = private_attrib_attrs ,
84+ loc = loc ,
85+ ip = ip ,
86+ )
7787
7888 if isinstance (sym_name , str ):
7989 self .attributes [self .SYM_NAME_ATTR_NAME ] = StringAttr .get (sym_name )
0 commit comments