|
6 | 6 | from .._gpu_ops_gen import _Dialect |
7 | 7 | from .._gpu_enum_gen import * |
8 | 8 | from ..._mlir_libs._mlirDialectsGPU import * |
9 | | -from typing import Callable, Sequence, Union, Optional, List |
| 9 | +from typing import Any, Callable, Sequence, Tuple, Union, Optional, List |
10 | 10 |
|
11 | 11 | try: |
12 | 12 | from ...ir import ( |
|
21 | 21 | DictAttr, |
22 | 22 | Attribute, |
23 | 23 | DenseI32ArrayAttr, |
| 24 | + Value, |
24 | 25 | ) |
| 26 | + from ...extras.meta import region_op |
| 27 | + from ...extras import types as T |
| 28 | + from ..arith import constant, ConstantOp |
25 | 29 | from .._ods_common import ( |
26 | 30 | get_default_loc_context as _get_default_loc_context, |
27 | 31 | _cext as _ods_cext, |
| 32 | + get_op_result_or_op_results, |
28 | 33 | ) |
29 | 34 | except ImportError as e: |
30 | 35 | raise RuntimeError("Error loading imports from extension module") from e |
31 | 36 |
|
32 | 37 |
|
def gpu_async_token():
    """Return the MLIR `!gpu.async.token` type used to chain async GPU ops."""
    return Type.parse("!gpu.async.token")
| 41 | + |
33 | 42 | @_ods_cext.register_operation(_Dialect, replace=True) |
34 | 43 | class GPUFuncOp(GPUFuncOp): |
35 | 44 | __doc__ = GPUFuncOp.__doc__ |
@@ -151,3 +160,176 @@ def entry_block(self) -> Block: |
    @property
    def arguments(self) -> Sequence[Type]:
        # Input types of the wrapped gpu.func, read off its FunctionType attribute.
        return self.function_type.value.inputs
| 163 | + |
| 164 | + |
def _convert_literal_to_constant(value: Union[int, ConstantOp, Value]) -> Value:
    """Coerce *value* to an SSA value usable as a launch dimension.

    SSA values (``Value``/``ConstantOp``) pass through unchanged; a Python
    ``int`` is materialized as an index-typed ``arith.constant``.

    Raises:
        ValueError: if *value* is none of the accepted kinds.
    """
    if isinstance(value, (ConstantOp, Value)):
        return value
    if isinstance(value, int):
        return constant(T.index(), value)
    raise ValueError(f"Invalid value: {value}")
| 172 | + |
| 173 | + |
@_ods_cext.register_operation(_Dialect, replace=True)
class LaunchFuncOp(LaunchFuncOp):
    __doc__ = LaunchFuncOp.__doc__

    def __init__(
        self,
        kernel: List[str],
        grid_size: Tuple[Any, Any, Any],
        block_size: Tuple[Any, Any, Any],
        kernel_operands: Optional[List[Value]] = None,
        async_dependencies: Optional[List[Value]] = None,
        dynamic_shared_memory_size: Optional[Value] = None,
        async_object=None,
        *,
        loc=None,
        ip=None,
    ):
        """Build a `gpu.launch_func` op.

        Args:
            kernel: symbol reference path to the kernel function.
            grid_size: per-dimension grid sizes; ints are converted to index
                constants, SSA values are used as-is.
            block_size: per-dimension block sizes, converted the same way.
            kernel_operands: SSA values forwarded to the kernel.
            async_dependencies: `!gpu.async.token` values to wait on; when
                non-empty the op also produces an async token result.
            dynamic_shared_memory_size: optional dynamic shared memory size.
        """
        if async_dependencies is None:
            async_dependencies = []
        # The op yields an async token result only when it has dependencies.
        async_token = gpu_async_token() if async_dependencies else None

        gx, gy, gz = (_convert_literal_to_constant(d) for d in grid_size)
        bx, by, bz = (_convert_literal_to_constant(d) for d in block_size)

        super().__init__(
            async_token,
            async_dependencies,
            kernel,
            gx,
            gy,
            gz,
            bx,
            by,
            bz,
            kernel_operands,
            dynamicSharedMemorySize=dynamic_shared_memory_size,
            asyncObject=async_object,
            loc=loc,
            ip=ip,
        )
| 220 | + |
| 221 | + |
def launch_func(
    kernel: List[str],
    grid_size: Tuple[Any, Any, Any],
    block_size: Tuple[Any, Any, Any],
    kernel_operands: Optional[List[Value]] = None,
    async_dependencies: Optional[List[Value]] = None,
    dynamic_shared_memory_size: Optional[Value] = None,
    async_object=None,
    *,
    loc=None,
    ip=None,
) -> Union[Value, List[Value], LaunchFuncOp]:
    """Create a `gpu.launch_func` and unwrap its result(s).

    Returns the op's single result (the async token) when one is produced,
    its results when several, and the op itself when it has none.
    """
    op = LaunchFuncOp(
        kernel=kernel,
        grid_size=grid_size,
        block_size=block_size,
        kernel_operands=kernel_operands,
        async_dependencies=async_dependencies,
        dynamic_shared_memory_size=dynamic_shared_memory_size,
        async_object=async_object,
        loc=loc,
        ip=ip,
    )
    # Use the shared ODS helper (same convention as `wait` below) instead of
    # re-implementing the result/results/op unwrapping by hand.
    return get_op_result_or_op_results(op)
| 252 | + |
| 253 | + |
def wait(
    async_dependencies: Optional[List[Value]] = None, *, loc=None, ip=None
) -> Union[Value, List[Value], WaitOp]:
    """Emit a `gpu.wait` on *async_dependencies*, returning its async token."""
    deps = [] if async_dependencies is None else async_dependencies
    op = WaitOp(gpu_async_token(), deps, loc=loc, ip=ip)
    return get_op_result_or_op_results(op)
| 262 | + |
| 263 | + |
@_ods_cext.register_operation(_Dialect, replace=True)
class LaunchOp(LaunchOp):
    __doc__ = LaunchOp.__doc__

    def __init__(
        self,
        grid_size: Tuple[Any, Any, Any],
        block_size: Tuple[Any, Any, Any],
        async_dependencies=None,
        dynamic_shared_memory_size: Optional[Value] = None,
        *,
        loc=None,
        ip=None,
    ):
        """Build a `gpu.launch` op with an empty body block ready to populate.

        Args:
            grid_size: per-dimension grid sizes; ints become index constants.
            block_size: per-dimension block sizes, converted the same way.
            async_dependencies: optional `!gpu.async.token` values to wait on;
                when non-empty the op also produces an async token result.
            dynamic_shared_memory_size: optional dynamic shared memory size.
        """
        if async_dependencies is None:
            async_dependencies = []
        async_token = None
        # The op yields an async token result only when it has dependencies.
        if len(async_dependencies):
            async_token = gpu_async_token()
        grid_size_x, grid_size_y, grid_size_z = map(
            _convert_literal_to_constant, grid_size
        )
        block_size_x, block_size_y, block_size_z = map(
            _convert_literal_to_constant, block_size
        )

        super().__init__(
            async_token,
            async_dependencies,
            grid_size_x,
            grid_size_y,
            grid_size_z,
            block_size_x,
            block_size_y,
            block_size_z,
            dynamicSharedMemorySize=dynamic_shared_memory_size,
            loc=loc,
            ip=ip,
        )
        # Pre-create the body block with 12 index arguments — presumably the
        # block/thread ids and grid/block dims (x, y, z each) that gpu.launch's
        # region expects; no cluster dimensions here. TODO(review): confirm
        # against the GPU dialect's op definition.
        self.regions[0].blocks.append(*[T.index() for _ in range(12)])
| 304 | + |
| 305 | + |
def launch_(
    grid_size: Tuple[Any, Any, Any],
    block_size: Tuple[Any, Any, Any],
    async_dependencies=None,
    dynamic_shared_memory_size: Optional[Value] = None,
    *,
    loc=None,
    ip=None,
):
    """Body builder used by the `launch` region decorator below.

    Returns the created `LaunchOp`; its (pre-created) body block is then
    populated by `region_op`.
    """
    # LaunchOp.__init__ already coerces int literals to index constants and
    # passes SSA values through, so no pre-conversion is needed here.
    return LaunchOp(
        grid_size,
        block_size,
        async_dependencies,
        dynamic_shared_memory_size,
        loc=loc,
        ip=ip,
    )
| 326 | + |
| 327 | + |
# Region-building decorator: the wrapped function becomes the body of a
# `gpu.launch`, and region_op appends the terminator (presumably
# `gpu.terminator` from the generated ops — verify against _gpu_ops_gen).
launch = region_op(launch_, terminator=lambda *_args: terminator())
| 329 | + |
| 330 | + |
# Keep a handle to the generated `printf` builder before shadowing it below.
_printf = printf


def printf(format, *args, loc=None, ip=None):
    # Variadic convenience wrapper: accepts the format's arguments
    # positionally instead of requiring an explicit `args` sequence.
    return _printf(format=format, args=args, loc=loc, ip=ip)
0 commit comments