|
1 |
| -import warnings |
2 |
| - |
3 | 1 | # noinspection PyUnresolvedReferences
|
4 | 2 | from .....dialects.llvm import *
|
5 |
| -from .....ir import Type, F16Type, F32Type, F64Type, BF16Type, IntegerType |
6 |
| - |
7 |
| -try: |
8 |
| - from llvm import intrinsic_is_overloaded, intrinsic_get_type, print_type_to_string |
9 |
| - from llvm import types_ |
10 |
| - from llvm.context import context as llvm_context |
11 |
| -except ImportError: |
12 |
| - warnings.warn( |
13 |
| - "llvm bindings not installed; call_intrinsic won't work without supplying return type explicitly" |
14 |
| - ) |
| 3 | +from .....ir import Type, Value |
15 | 4 |
|
| 5 | +ValueRef = Value |
16 | 6 |
|
17 | 7 | def llvm_ptr_t():
|
18 | 8 | return Type.parse("!llvm.ptr")
|
19 | 9 |
|
20 |
| - |
21 |
| -def mlir_type_to_llvm_type(mlir_type, llvm_ctx): |
22 |
| - if F16Type.isinstance(mlir_type): |
23 |
| - return types_.half_type_in_context(llvm_ctx) |
24 |
| - if F32Type.isinstance(mlir_type): |
25 |
| - return types_.float_type_in_context(llvm_ctx) |
26 |
| - if F64Type.isinstance(mlir_type): |
27 |
| - return types_.double_type_in_context(llvm_ctx) |
28 |
| - if BF16Type.isinstance(mlir_type): |
29 |
| - return types_.b_float_type_in_context(llvm_ctx) |
30 |
| - if IntegerType.isinstance(mlir_type): |
31 |
| - return types_.int_type_in_context(llvm_ctx, mlir_type.width) |
32 |
| - |
33 |
| - raise NotImplementedError(f"{mlir_type} is not supported") |
34 |
| - |
35 |
| - |
36 |
| -def llvm_type_str_to_mlir_type(llvm_type: str): |
37 |
| - if llvm_type.startswith("<"): |
38 |
| - return Type.parse(f"vector{llvm_type}") |
39 |
| - if llvm_type == "float": |
40 |
| - return F32Type.get() |
41 |
| - raise NotImplementedError(f"{llvm_type} is not supported") |
42 |
| - |
43 |
| - |
44 |
| -_call_intrinsic = call_intrinsic |
45 |
| - |
46 |
| - |
47 |
| -def call_intrinsic(*args, **kwargs): |
48 |
| - intr_id = kwargs.pop("intr_id") |
49 |
| - intr_name = kwargs.pop("intr_name") |
50 |
| - mlir_ret_type = kwargs.pop("return_type", None) |
51 |
| - if mlir_ret_type: |
52 |
| - return _call_intrinsic(mlir_ret_type, intr_name, args, [], []) |
53 |
| - |
54 |
| - is_overloaded = kwargs.pop("is_overloaded", None) |
55 |
| - if is_overloaded is None: |
56 |
| - is_overloaded = intrinsic_is_overloaded(intr_id) |
57 |
| - with llvm_context() as ctx: |
58 |
| - types = [] |
59 |
| - if is_overloaded: |
60 |
| - types = [mlir_type_to_llvm_type(a.type, ctx.context) for a in args] |
61 |
| - intr_decl_fn_ty = intrinsic_get_type(ctx.context, intr_id, types) |
62 |
| - |
63 |
| - ret_type_str = print_type_to_string(intr_decl_fn_ty).split(" (")[0].strip() |
64 |
| - mlir_ret_type = None |
65 |
| - if ret_type_str: |
66 |
| - mlir_ret_type = llvm_type_str_to_mlir_type(ret_type_str) |
67 |
| - |
68 |
| - return _call_intrinsic(mlir_ret_type, intr_name, args, [], []) |
69 |
| - |
70 |
| - |
71 |
| -call_intrinsic_ = call_intrinsic |
72 |
| - |
73 |
| -from . import amdgcn |
| 10 | +try: |
| 11 | + from . import amdgcn |
| 12 | +except ImportError: |
| 13 | + pass |
0 commit comments