Skip to content

Commit 3e77eb7

Browse files
authored
memref.global wrapper (#64)
1 parent 0b864dc commit 3e77eb7

File tree

4 files changed

+128
-11
lines changed

4 files changed

+128
-11
lines changed

mlir/extras/context.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import contextlib
2+
import warnings
23
from contextlib import ExitStack, contextmanager
34
from dataclasses import dataclass
45
from typing import Optional
@@ -58,7 +59,7 @@ def __del__(self):
5859
self.context.__exit__(None, None, None)
5960
# i guess the extension gets destroyed before this object sometimes?
6061
if ir is not None:
61-
assert ir.Context.current is None, str(ir.Context.current)
62+
assert ir.Context is not self.context
6263

6364

6465
class ExplicitlyManagedModule:

mlir/extras/dialects/ext/memref.py

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,24 @@
1+
import inspect
12
from functools import cached_property, reduce
23
from typing import Tuple, Sequence, Union
34

5+
import numpy as np
6+
47
from .arith import Scalar, constant
58
from .tensor import _indices_to_indexer, compute_result_shape_reassoc_list
69
from ... import types as T
710
from ...meta import region_op
8-
from ...util import get_user_code_loc, _unpack_sizes_element_type
11+
from ...util import (
12+
get_user_code_loc,
13+
_unpack_sizes_element_type,
14+
_get_sym_name,
15+
infer_mlir_type,
16+
)
917
from ...._mlir_libs._mlir import register_value_caster
1018
from ....dialects import memref, arith
1119
from ....dialects._ods_common import get_op_result_or_op_results
1220
from ....dialects.memref import *
13-
from ....ir import Type, Value, MemRefType, ShapedType
21+
from ....ir import Type, Value, MemRefType, ShapedType, DenseElementsAttr
1422

1523
S = ShapedType.get_dynamic_size()
1624

@@ -286,3 +294,44 @@ def dim(source, index, *, loc=None, ip=None):
286294
if isinstance(index, int):
287295
index = constant(index, index=True)
288296
return _dim(source=source, index=index, loc=loc, ip=ip)
297+
298+
299+
def global_(
300+
initial_value=None,
301+
sym_name=None,
302+
type_=None,
303+
sym_visibility="private",
304+
constant=None,
305+
alignment=None,
306+
loc=None,
307+
ip=None,
308+
):
309+
if sym_name is None:
310+
previous_frame = inspect.currentframe().f_back
311+
sym_name = _get_sym_name(
312+
previous_frame, check_func_call="memref\.global_|global_"
313+
)
314+
if loc is None:
315+
loc = get_user_code_loc()
316+
if initial_value is None:
317+
assert type_ is not None
318+
else:
319+
assert isinstance(initial_value, np.ndarray)
320+
type_ = infer_mlir_type(initial_value, memref=True)
321+
initial_value = DenseElementsAttr.get(
322+
initial_value,
323+
type=type_.element_type,
324+
context=None,
325+
)
326+
constant = True
327+
328+
return memref.global_(
329+
sym_name,
330+
type_,
331+
sym_visibility=sym_visibility,
332+
initial_value=initial_value,
333+
constant=constant,
334+
alignment=alignment,
335+
loc=loc,
336+
ip=ip,
337+
).opview

mlir/extras/util.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import ctypes
33
import inspect
44
import platform
5+
import re
56
import sys
67
import warnings
78
from dataclasses import dataclass
@@ -23,6 +24,7 @@
2324
Location,
2425
OpView,
2526
Operation,
27+
MemRefType,
2628
RankedTensorType,
2729
Value,
2830
_GlobalDebug,
@@ -165,7 +167,7 @@ def mlir_type_to_ctype(mlir_type):
165167

166168

167169
def infer_mlir_type(
168-
py_val: Union[int, float, bool, np.ndarray]
170+
py_val: Union[int, float, bool, np.ndarray], memref=False
169171
) -> Union[IntegerType, F32Type, F64Type, RankedTensorType]:
170172
"""Infer MLIR type (`ir.Type`) from supported python values.
171173
@@ -202,7 +204,10 @@ def infer_mlir_type(
202204
return T.f64()
203205
elif isinstance(py_val, np.ndarray):
204206
dtype = np_dtype_to_mlir_type(py_val.dtype.type)
205-
return RankedTensorType.get(py_val.shape, dtype)
207+
if memref:
208+
return MemRefType.get(py_val.shape, dtype)
209+
else:
210+
return RankedTensorType.get(py_val.shape, dtype)
206211
else:
207212
raise NotImplementedError(
208213
f"Unsupported Python value {py_val=} with type {type(py_val)}"
@@ -356,8 +361,8 @@ def _get_sym_name(previous_frame, check_func_call=None):
356361
src_lines = src_file.readlines()
357362
src_line = src_lines[previous_frame.f_lineno - 1].strip()
358363
ident, func_call = map(lambda x: x.strip(), src_line.split("=", maxsplit=1))
359-
if check_func_call is None:
360-
assert check_func_call in func_call
364+
if check_func_call is not None:
365+
assert re.match(check_func_call, func_call)
361366
return ident
362367
except:
363368
return None

tests/test_memref.py

Lines changed: 66 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,31 @@
1-
import numpy as np
1+
import platform
22
import re
33
from textwrap import dedent
44

5+
import numpy as np
56
import pytest
6-
from mlir.ir import MLIRError, Type
77

88
import mlir.extras.types as T
9+
from mlir.dialects.memref import subview
910
from mlir.extras.ast.canonicalize import canonicalize
11+
from mlir.extras.dialects.ext import memref
1012
from mlir.extras.dialects.ext.arith import Scalar, constant
1113
from mlir.extras.dialects.ext.memref import (
1214
alloc,
1315
alloca,
14-
S,
1516
alloca_scope,
1617
alloca_scope_return,
18+
global_,
1719
)
1820
from mlir.extras.dialects.ext.scf import (
1921
range_,
2022
yield_,
2123
canonicalizer,
2224
)
23-
from mlir.dialects.memref import subview
2425

2526
# noinspection PyUnresolvedReferences
2627
from mlir.extras.testing import mlir_ctx as ctx, filecheck, MLIRContext
28+
from mlir.ir import MLIRError, Type
2729

2830
# needed since the fix isn't defined here nor conftest.py
2931
pytest.mark.usefixtures("ctx")
@@ -581,3 +583,63 @@ def tenfoo():
581583
)
582584

583585
filecheck(correct, ctx.module)
586+
587+
588+
@pytest.mark.skipif(
589+
platform.system() == "Windows",
590+
reason="On windows int64 is inferred to be i64 ",
591+
)
592+
def test_memref_global_windows(ctx: MLIRContext):
593+
k = 32
594+
weight1 = global_(np.ones((k,), dtype=np.int32))
595+
weight2 = global_(np.ones((k,), dtype=np.int64))
596+
weight3 = global_(np.ones((k,), dtype=np.float32))
597+
weight4 = global_(np.ones((k,), dtype=np.float64))
598+
weight5 = memref.global_(np.ones((k,), dtype=np.int16))
599+
weight6 = memref.global_(np.ones((k,), dtype=np.float16))
600+
print(ctx.module)
601+
602+
correct = dedent(
603+
"""\
604+
module {
605+
memref.global "private" constant @weight1 : memref<32xi32> = dense<1>
606+
memref.global "private" constant @weight2 : memref<32xi64> = dense<1>
607+
memref.global "private" constant @weight3 : memref<32xf32> = dense<1.000000e+00>
608+
memref.global "private" constant @weight4 : memref<32xf64> = dense<1.000000e+00>
609+
memref.global "private" constant @weight5 : memref<32xi16> = dense<1>
610+
memref.global "private" constant @weight6 : memref<32xf16> = dense<1.000000e+00>
611+
}
612+
"""
613+
)
614+
615+
filecheck(correct, ctx.module)
616+
617+
618+
@pytest.mark.skipif(
619+
platform.system() != "Windows",
620+
reason="On linux/mac int64 is inferred to be index (through np.longlong)",
621+
)
622+
def test_memref_global_non_windows(ctx: MLIRContext):
623+
k = 32
624+
weight1 = global_(np.ones((k,), dtype=np.int32))
625+
weight2 = global_(np.ones((k,), dtype=np.int64))
626+
weight3 = global_(np.ones((k,), dtype=np.float32))
627+
weight4 = global_(np.ones((k,), dtype=np.float64))
628+
weight5 = memref.global_(np.ones((k,), dtype=np.int16))
629+
weight6 = memref.global_(np.ones((k,), dtype=np.float16))
630+
print(ctx.module)
631+
632+
correct = dedent(
633+
"""\
634+
module {
635+
memref.global "private" constant @weight1 : memref<32xi32> = dense<1>
636+
memref.global "private" constant @weight2 : memref<32xindex> = dense<1>
637+
memref.global "private" constant @weight3 : memref<32xf32> = dense<1.000000e+00>
638+
memref.global "private" constant @weight4 : memref<32xf64> = dense<1.000000e+00>
639+
memref.global "private" constant @weight5 : memref<32xi16> = dense<1>
640+
memref.global "private" constant @weight6 : memref<32xf16> = dense<1.000000e+00>
641+
}
642+
"""
643+
)
644+
645+
filecheck(correct, ctx.module)

0 commit comments

Comments
 (0)