Skip to content

Commit b9f0335

Browse files
authored
fix llvm dialect (#120)
1 parent 56ac9da commit b9f0335

File tree

7 files changed

+98
-5273
lines changed

7 files changed

+98
-5273
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
amdgcn.py
Lines changed: 6 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -1,73 +1,13 @@
1-
import warnings
2-
31
# noinspection PyUnresolvedReferences
42
from .....dialects.llvm import *
5-
from .....ir import Type, F16Type, F32Type, F64Type, BF16Type, IntegerType
6-
7-
try:
8-
from llvm import intrinsic_is_overloaded, intrinsic_get_type, print_type_to_string
9-
from llvm import types_
10-
from llvm.context import context as llvm_context
11-
except ImportError:
12-
warnings.warn(
13-
"llvm bindings not installed; call_intrinsic won't work without supplying return type explicitly"
14-
)
3+
from .....ir import Type, Value
154

5+
ValueRef = Value
166

177
def llvm_ptr_t():
188
return Type.parse("!llvm.ptr")
199

20-
21-
def mlir_type_to_llvm_type(mlir_type, llvm_ctx):
22-
if F16Type.isinstance(mlir_type):
23-
return types_.half_type_in_context(llvm_ctx)
24-
if F32Type.isinstance(mlir_type):
25-
return types_.float_type_in_context(llvm_ctx)
26-
if F64Type.isinstance(mlir_type):
27-
return types_.double_type_in_context(llvm_ctx)
28-
if BF16Type.isinstance(mlir_type):
29-
return types_.b_float_type_in_context(llvm_ctx)
30-
if IntegerType.isinstance(mlir_type):
31-
return types_.int_type_in_context(llvm_ctx, mlir_type.width)
32-
33-
raise NotImplementedError(f"{mlir_type} is not supported")
34-
35-
36-
def llvm_type_str_to_mlir_type(llvm_type: str):
37-
if llvm_type.startswith("<"):
38-
return Type.parse(f"vector{llvm_type}")
39-
if llvm_type == "float":
40-
return F32Type.get()
41-
raise NotImplementedError(f"{llvm_type} is not supported")
42-
43-
44-
_call_intrinsic = call_intrinsic
45-
46-
47-
def call_intrinsic(*args, **kwargs):
48-
intr_id = kwargs.pop("intr_id")
49-
intr_name = kwargs.pop("intr_name")
50-
mlir_ret_type = kwargs.pop("return_type", None)
51-
if mlir_ret_type:
52-
return _call_intrinsic(mlir_ret_type, intr_name, args, [], [])
53-
54-
is_overloaded = kwargs.pop("is_overloaded", None)
55-
if is_overloaded is None:
56-
is_overloaded = intrinsic_is_overloaded(intr_id)
57-
with llvm_context() as ctx:
58-
types = []
59-
if is_overloaded:
60-
types = [mlir_type_to_llvm_type(a.type, ctx.context) for a in args]
61-
intr_decl_fn_ty = intrinsic_get_type(ctx.context, intr_id, types)
62-
63-
ret_type_str = print_type_to_string(intr_decl_fn_ty).split(" (")[0].strip()
64-
mlir_ret_type = None
65-
if ret_type_str:
66-
mlir_ret_type = llvm_type_str_to_mlir_type(ret_type_str)
67-
68-
return _call_intrinsic(mlir_ret_type, intr_name, args, [], [])
69-
70-
71-
call_intrinsic_ = call_intrinsic
72-
73-
from . import amdgcn
10+
try:
11+
from . import amdgcn
12+
except ImportError:
13+
pass

0 commit comments

Comments
 (0)