Skip to content

Commit e7592ab

Browse files
authored
fix nanobind differences (and others) (#110)
1 parent 8984cf8 commit e7592ab

File tree

17 files changed

+183
-75
lines changed

17 files changed

+183
-75
lines changed

.github/workflows/test.yml

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -34,18 +34,12 @@ jobs:
3434
fail-fast: false
3535
matrix:
3636
os: [ ubuntu-22.04, macos-13, macos-14, windows-2022 ]
37-
py_version: [ "3.8", "3.9", "3.10", "3.11", "3.12" ]
37+
py_version: [ "3.9", "3.10", "3.11", "3.12", "3.13" ]
3838

3939
exclude:
40-
- os: macos-13
41-
py_version: "3.8"
42-
4340
- os: macos-13
4441
py_version: "3.9"
4542

46-
- os: macos-14
47-
py_version: "3.8"
48-
4943
- os: macos-14
5044
py_version: "3.9"
5145

@@ -174,7 +168,7 @@ jobs:
174168
fail-fast: false
175169
matrix:
176170
os: [ ubuntu-22.04 ]
177-
py_version: [ "3.8", "3.9", "3.10", "3.11", "3.12" ]
171+
py_version: [ "3.9", "3.10", "3.11", "3.12", "3.13" ]
178172

179173
steps:
180174
- name: Checkout
@@ -189,7 +183,7 @@ jobs:
189183
install: |
190184
191185
apt-get update -q -y
192-
apt-get install -y wget build-essential
186+
apt-get install -y wget build-essential git
193187
194188
mkdir -p ~/miniconda3
195189
wget -q https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-aarch64.sh -O miniconda.sh

examples/mwe.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,8 @@ def pats():
131131
.finalize_memref_to_llvm()
132132
# Convert Func to LLVM (always needed).
133133
.convert_func_to_llvm()
134+
.convert_arith_to_llvm()
135+
.convert_cf_to_llvm()
134136
# Convert Index to LLVM (always needed).
135137
.convert_index_to_llvm()
136138
# Convert remaining unrealized_casts (always needed).

examples/vectorization_e2e.ipynb

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,8 @@
424424
" .finalize_memref_to_llvm()\n",
425425
" # Convert Func to LLVM (always needed).\n",
426426
" .convert_func_to_llvm()\n",
427+
" .convert_arith_to_llvm()\n",
428+
" .convert_cf_to_llvm()\n",
427429
" # Convert Index to LLVM (always needed).\n",
428430
" .convert_index_to_llvm()\n",
429431
" # Convert remaining unrealized_casts (always needed).\n",

mlir/extras/ast/canonicalize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def transform_ast(
9898
module_code_o = compile(module, f.__code__.co_filename, "exec")
9999
new_f_code_o = find_func_in_code_object(module_code_o, f.__name__)
100100
n_lines = len(inspect.getsource(f).splitlines())
101-
line_starts = list(findlinestarts(new_f_code_o))
101+
line_starts = list(filter(lambda el: el[1], findlinestarts(new_f_code_o)))
102102
if (
103103
max([l for _, l in line_starts]) - min([l for _, l in line_starts]) + 1
104104
> n_lines

mlir/extras/dialects/ext/_shaped_value.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010

1111
# mixin that requires `is_constant`
12-
class ShapedValue:
12+
def ShapedValue(cls):
1313
@cached_property
1414
def literal_value(self) -> np.ndarray:
1515
if not self.is_constant:
@@ -42,3 +42,22 @@ def n_elements(self) -> int:
4242
@cached_property
4343
def dtype(self) -> Type:
4444
return self._shaped_type.element_type
45+
46+
setattr(cls, "literal_value", literal_value)
47+
cls.literal_value.__set_name__(None, "literal_value")
48+
setattr(cls, "_shaped_type", _shaped_type)
49+
cls._shaped_type.__set_name__(None, "_shaped_type")
50+
51+
setattr(cls, "has_static_shape", has_static_shape)
52+
setattr(cls, "has_rank", has_rank)
53+
54+
setattr(cls, "rank", rank)
55+
cls.rank.__set_name__(None, "rank")
56+
setattr(cls, "shape", shape)
57+
cls.shape.__set_name__(None, "shape")
58+
setattr(cls, "n_elements", n_elements)
59+
cls.n_elements.__set_name__(None, "n_elements")
60+
setattr(cls, "dtype", dtype)
61+
cls.dtype.__set_name__(None, "dtype")
62+
63+
return cls

mlir/extras/dialects/ext/arith.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from typing import Optional, Tuple, Union
88

99
from bytecode import ConcreteBytecode
10+
from einspect.structs import PyTypeObject
1011

1112
from ...ast.canonicalize import StrictTransformer, Canonicalizer, BytecodePatcher
1213
from ...ast.util import ast_call
@@ -138,7 +139,13 @@ def index_cast(
138139
)
139140

140141

141-
class ArithValueMeta(type(Value)):
142+
nb_meta_cls = type(Value)
143+
144+
_Py_TPFLAGS_BASETYPE = 1 << 10
145+
PyTypeObject.from_object(nb_meta_cls).tp_flags |= _Py_TPFLAGS_BASETYPE
146+
147+
148+
class ArithValueMeta(nb_meta_cls):
142149
"""Metaclass that orchestrates the Python object protocol
143150
(i.e., calling __new__ and __init__) for Indexing dialect extension values
144151
(created using `mlir_value_subclass`).

mlir/extras/dialects/ext/func.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,9 @@ def __init__(
192192

193193
def _is_decl(self):
194194
# magic constant found from looking at the code for an empty fn
195-
if sys.version_info.minor == 12:
195+
if sys.version_info.minor == 13:
196+
return self.body_builder.__code__.co_code == b"\x95\x00g\x00"
197+
elif sys.version_info.minor == 12:
196198
return self.body_builder.__code__.co_code == b"\x97\x00y\x00"
197199
elif sys.version_info.minor == 11:
198200
return self.body_builder.__code__.co_code == b"\x97\x00d\x00S\x00"

mlir/extras/dialects/ext/memref.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,8 @@ def store(
129129

130130

131131
@register_value_caster(MemRefType.static_typeid)
132-
class MemRef(Value, ShapedValue):
132+
@ShapedValue
133+
class MemRef(Value):
133134
def __str__(self):
134135
return f"{self.__class__.__name__}({self.get_name()}, {self.type})"
135136

mlir/extras/dialects/ext/tensor.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,8 @@ def insert_slice(
109109

110110
# TODO(max): unify vector/memref/tensor
111111
@register_value_caster(RankedTensorType.static_typeid)
112-
class Tensor(ShapedValue, ArithValue):
112+
@ShapedValue
113+
class Tensor(ArithValue):
113114
def __getitem__(self, idx: tuple) -> "Tensor":
114115
loc = get_user_code_loc()
115116

mlir/extras/dialects/ext/vector.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414

1515

1616
@register_value_caster(VectorType.static_typeid)
17-
class Vector(ShapedValue, ArithValue):
17+
@ShapedValue
18+
class Vector(ArithValue):
1819
def __getitem__(self, idx: tuple) -> "Vector":
1920
loc = get_user_code_loc()
2021

@@ -105,7 +106,7 @@ def transfer_read(
105106
if isinstance(padding, int):
106107
padding = constant(padding, type=source.type.element_type)
107108
if in_bounds is None:
108-
in_bounds = [None] * len(permutation_map.results)
109+
raise ValueError("in_bounds cannot be None")
109110

110111
return _transfer_read(
111112
vector=vector_t,

0 commit comments

Comments
 (0)