Skip to content

Commit 4d77301

Browse files
authored
amdgcn for llvm dialect (#116)
1 parent 01a6b17 commit 4d77301

File tree

7 files changed

+5332
-12
lines changed

7 files changed

+5332
-12
lines changed

.github/workflows/test.yml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ concurrency:
2323

2424
env:
2525
SYSTEM_VERSION_COMPAT: 0
26+
PIP_FIND_LINKS: "https://github.com/llvm/eudsl/releases/expanded_assets/latest https://github.com/makslevental/mlir-wheels/releases/expanded_assets/latest"
2627

2728
jobs:
2829

@@ -43,6 +44,8 @@ jobs:
4344
- os: macos-14
4445
py_version: "3.9"
4546

47+
name: "${{ matrix.os }}-${{ matrix.py_version }}"
48+
4649
steps:
4750
- name: Checkout
4851
uses: actions/checkout@v2
@@ -56,7 +59,7 @@ jobs:
5659
- name: Install and configure
5760
shell: bash
5861
run: |
59-
pip install .[test,mlir] -v -f https://makslevental.github.io/wheels
62+
pip install .[test,mlir] -v
6063
6164
- name: Test
6265
shell: bash
@@ -95,7 +98,6 @@ jobs:
9598
- name: Install and configure
9699
shell: bash
97100
run: |
98-
export PIP_FIND_LINKS=https://makslevental.github.io/wheels
99101
pip install .[test,mlir] -v
100102
HOST_MLIR_PYTHON_PACKAGE_PREFIX=jaxlib.mlir pip install .[test,jax] -v
101103
@@ -133,7 +135,7 @@ jobs:
133135
run: |
134136
135137
pip install jupyter
136-
pip install -q mlir-python-bindings -f https://makslevental.github.io/wheels
138+
pip install -q mlir-python-bindings
137139
pip install -q .
138140
139141
sed -i.bak 's/OUTPUT_TIMEOUT = 10/OUTPUT_TIMEOUT = 100/g' \

mlir/extras/dialects/ext/llvm.py

Lines changed: 0 additions & 8 deletions
This file was deleted.
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
import warnings
2+
3+
# noinspection PyUnresolvedReferences
4+
from .....dialects.llvm import *
5+
from .....ir import Type, F16Type, F32Type, F64Type, BF16Type, IntegerType
6+
7+
try:
8+
from llvm import intrinsic_is_overloaded, intrinsic_get_type, print_type_to_string
9+
from llvm import types_
10+
from llvm.context import context as llvm_context
11+
except ImportError:
12+
warnings.warn(
13+
"llvm bindings not installed; call_intrinsic won't work without supplying return type explicitly"
14+
)
15+
16+
17+
def llvm_ptr_t():
18+
return Type.parse("!llvm.ptr")
19+
20+
21+
def mlir_type_to_llvm_type(mlir_type, llvm_ctx):
22+
if F16Type.isinstance(mlir_type):
23+
return types_.half_type_in_context(llvm_ctx)
24+
if F32Type.isinstance(mlir_type):
25+
return types_.float_type_in_context(llvm_ctx)
26+
if F64Type.isinstance(mlir_type):
27+
return types_.double_type_in_context(llvm_ctx)
28+
if BF16Type.isinstance(mlir_type):
29+
return types_.b_float_type_in_context(llvm_ctx)
30+
if IntegerType.isinstance(mlir_type):
31+
return types_.int_type_in_context(llvm_ctx, mlir_type.width)
32+
33+
raise NotImplementedError(f"{mlir_type} is not supported")
34+
35+
36+
def llvm_type_str_to_mlir_type(llvm_type: str):
37+
if llvm_type.startswith("<"):
38+
return Type.parse(f"vector{llvm_type}")
39+
if llvm_type == "float":
40+
return F32Type.get()
41+
raise NotImplementedError(f"{llvm_type} is not supported")
42+
43+
44+
_call_intrinsic = call_intrinsic
45+
46+
47+
def call_intrinsic(*args, **kwargs):
48+
intr_id = kwargs.pop("intr_id")
49+
intr_name = kwargs.pop("intr_name")
50+
mlir_ret_type = kwargs.pop("return_type", None)
51+
if mlir_ret_type:
52+
return _call_intrinsic(mlir_ret_type, intr_name, args, [], [])
53+
54+
is_overloaded = kwargs.pop("is_overloaded", None)
55+
if is_overloaded is None:
56+
is_overloaded = intrinsic_is_overloaded(intr_id)
57+
with llvm_context() as ctx:
58+
types = []
59+
if is_overloaded:
60+
types = [mlir_type_to_llvm_type(a.type, ctx.context) for a in args]
61+
intr_decl_fn_ty = intrinsic_get_type(ctx.context, intr_id, types)
62+
63+
ret_type_str = print_type_to_string(intr_decl_fn_ty).split(" (")[0].strip()
64+
mlir_ret_type = None
65+
if ret_type_str:
66+
mlir_ret_type = llvm_type_str_to_mlir_type(ret_type_str)
67+
68+
return _call_intrinsic(mlir_ret_type, intr_name, args, [], [])
69+
70+
71+
call_intrinsic_ = call_intrinsic
72+
73+
from . import amdgcn

0 commit comments

Comments
 (0)