Skip to content

Commit a14994f

Browse files
authored
arith constant vec (#65)
1 parent 3e77eb7 commit a14994f

26 files changed

+203
-210
lines changed

examples/mwe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
# noinspection PyUnresolvedReferences
1212
import mlir.extras.dialects.ext.memref
1313
from mlir.extras.context import RAIIMLIRContext, ExplicitlyManagedModule
14-
from mlir.extras.dialects.ext.bufferization import LayoutMapOption
14+
from mlir.dialects.bufferization import LayoutMapOption
1515
from mlir.dialects.transform.vector import (
1616
VectorContractLowering,
1717
VectorMultiReductionLowering,

examples/vectorization_e2e.ipynb

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,10 @@
5757
"from mlir.ir import StringAttr, UnitAttr\n",
5858
"\n",
5959
"# you need this to register the memref value caster\n",
60+
"# noinspection PyUnresolvedReferences\n",
6061
"import mlir.extras.dialects.ext.memref\n",
6162
"from mlir.extras.context import RAIIMLIRContext, ExplicitlyManagedModule\n",
62-
"from mlir.extras.dialects.ext.bufferization import LayoutMapOption\n",
63+
"from mlir.dialects.bufferization import LayoutMapOption\n",
6364
"from mlir.dialects.transform.vector import (\n",
6465
" VectorContractLowering,\n",
6566
" VectorMultiReductionLowering,\n",

mlir/extras/context.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import contextlib
2-
import warnings
32
from contextlib import ExitStack, contextmanager
43
from dataclasses import dataclass
54
from typing import Optional
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
from functools import cached_property, reduce
2+
from typing import Tuple
3+
4+
import numpy as np
5+
6+
from ....ir import DenseElementsAttr, ShapedType, Type
7+
8+
S = ShapedType.get_dynamic_size()
9+
10+
11+
# mixin that requires `is_constant`
12+
class ShapedValue:
13+
@cached_property
14+
def literal_value(self) -> np.ndarray:
15+
if not self.is_constant:
16+
raise ValueError("Can't build literal from non-constant value")
17+
return np.array(DenseElementsAttr(self.owner.opview.value), copy=False)
18+
19+
@cached_property
20+
def _shaped_type(self) -> ShapedType:
21+
return ShapedType(self.type)
22+
23+
def has_static_shape(self) -> bool:
24+
return self._shaped_type.has_static_shape
25+
26+
def has_rank(self) -> bool:
27+
return self._shaped_type.has_rank
28+
29+
@cached_property
30+
def shape(self) -> Tuple[int, ...]:
31+
return tuple(self._shaped_type.shape)
32+
33+
@cached_property
34+
def n_elements(self) -> int:
35+
assert self.has_static_shape()
36+
return reduce(lambda acc, v: acc * v, self._shaped_type.shape, 1)
37+
38+
@cached_property
39+
def dtype(self) -> Type:
40+
return self._shaped_type.element_type

mlir/extras/dialects/ext/arith.py

Lines changed: 17 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,50 +1,44 @@
11
import operator
22
from abc import abstractmethod
33
from copy import deepcopy
4-
from functools import partialmethod, cached_property
5-
from typing import Union, Optional, Tuple
6-
7-
import numpy as np
4+
from functools import cached_property, partialmethod
5+
from typing import Optional, Tuple
86

97
from ...util import get_user_code_loc, infer_mlir_type, mlir_type_to_np_dtype
108
from ...._mlir_libs._mlir import register_value_caster
11-
from ....dialects import arith as arith_dialect
12-
from ....dialects import complex as complex_dialect
9+
from ....dialects import arith as arith_dialect, complex as complex_dialect
1310
from ....dialects._arith_enum_gen import (
1411
_arith_cmpfpredicateattr,
15-
CmpFPredicate,
16-
CmpIPredicate,
1712
_arith_cmpipredicateattr,
1813
)
19-
from ....dialects._ods_common import get_op_result_or_value, get_op_result_or_op_results
14+
from ....dialects._ods_common import get_op_result_or_op_results, get_op_result_or_value
2015
from ....dialects.arith import *
2116
from ....dialects.arith import _is_integer_like_type
2217
from ....dialects.linalg.opdsl.lang.emitter import (
23-
_is_floating_point_type,
24-
_is_integer_type,
2518
_is_complex_type,
19+
_is_floating_point_type,
2620
_is_index_type,
2721
)
2822
from ....ir import (
2923
Attribute,
24+
BF16Type,
25+
ComplexType,
3026
Context,
3127
DenseElementsAttr,
28+
F16Type,
29+
F32Type,
30+
F64Type,
31+
FloatAttr,
3232
IndexType,
3333
InsertionPoint,
3434
IntegerType,
3535
Location,
3636
OpView,
3737
Operation,
38-
RankedTensorType,
38+
ShapedType,
3939
Type,
4040
Value,
4141
register_attribute_builder,
42-
ComplexType,
43-
BF16Type,
44-
F16Type,
45-
F32Type,
46-
F64Type,
47-
FloatAttr,
4842
)
4943

5044

@@ -53,6 +47,7 @@ def constant(
5347
type: Optional[Type] = None,
5448
index: Optional[bool] = None,
5549
*,
50+
vector: Optional[bool] = False,
5651
loc: Location = None,
5752
ip: InsertionPoint = None,
5853
) -> Value:
@@ -75,7 +70,7 @@ def constant(
7570
if index is not None and index:
7671
type = IndexType.get()
7772
if type is None:
78-
type = infer_mlir_type(value)
73+
type = infer_mlir_type(value, vector=vector)
7974

8075
assert type is not None
8176

@@ -98,8 +93,8 @@ def constant(
9893
if _is_floating_point_type(type) and not isinstance(value, np.ndarray):
9994
value = float(value)
10095

101-
if RankedTensorType.isinstance(type) and isinstance(value, (int, float, bool)):
102-
ranked_tensor_type = RankedTensorType(type)
96+
if ShapedType.isinstance(type) and isinstance(value, (int, float, bool)):
97+
ranked_tensor_type = ShapedType(type)
10398
value = np.full(
10499
ranked_tensor_type.shape,
105100
value,
@@ -403,7 +398,7 @@ def __str__(self):
403398
return f"{self.__class__.__name__}({self.get_name()}, {self.type})"
404399

405400
def __repr__(self):
406-
return str(self)
401+
return str(Value(self)).replace("Value", self.__class__.__name__)
407402

408403
# partialmethod differs from partial in that it also binds the object instance
409404
# to the first arg (i.e., self)
@@ -468,16 +463,6 @@ class Scalar(ArithValue):
468463
def dtype(self) -> Type:
469464
return self.type
470465

471-
@staticmethod
472-
def isinstance(other: Value):
473-
return (
474-
isinstance(other, Value)
475-
and _is_integer_type(other.type)
476-
or _is_floating_point_type(other.type)
477-
or _is_index_type(other.type)
478-
or _is_complex_type(other.type)
479-
)
480-
481466
@cached_property
482467
def literal_value(self) -> Union[int, float, bool]:
483468
if not self.is_constant():

mlir/extras/dialects/ext/bufferization.py

Lines changed: 0 additions & 2 deletions
This file was deleted.

mlir/extras/dialects/ext/cf.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
from typing import Union, List
1+
from typing import List, Union
22

3-
from ...util import get_user_code_loc, Successor
3+
from ...util import Successor, get_user_code_loc
44
from ....dialects._cf_ops_gen import _Dialect
55
from ....dialects._ods_common import (
66
_cext,
77
)
88
from ....dialects.cf import *
9-
from ....ir import Value, InsertionPoint, Block
9+
from ....ir import Block, InsertionPoint, Value
1010

1111

1212
@_cext.register_operation(_Dialect, replace=True)

mlir/extras/dialects/ext/func.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,17 @@
33

44
from ...meta import op_region_builder
55
from ...util import get_user_code_loc, make_maybe_no_args_decorator
6-
from ....dialects.func import *
76
from ....dialects._ods_common import get_op_result_or_op_results
7+
from ....dialects.func import *
88
from ....ir import (
99
FlatSymbolRefAttr,
1010
FunctionType,
1111
InsertionPoint,
12+
OpView,
13+
Operation,
1214
Type,
1315
TypeAttr,
1416
Value,
15-
Operation,
16-
OpView,
1717
)
1818

1919

mlir/extras/dialects/ext/gpu.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,31 @@
11
import inspect
22
from functools import partial
3-
from typing import Optional, Any, List, Tuple
3+
from typing import Any, List, Optional, Tuple
44

55
from .arith import constant
66
from .func import FuncBase
77
from ... import types as T
88
from ...meta import (
99
region_op,
1010
)
11-
from ...util import get_user_code_loc, make_maybe_no_args_decorator, ModuleMeta
11+
from ...util import ModuleMeta, get_user_code_loc, make_maybe_no_args_decorator
1212
from ....dialects._gpu_ops_gen import _Dialect
13-
from ....dialects._ods_common import get_default_loc_context, _cext
14-
from ....dialects._ods_common import get_op_result_or_op_results
13+
from ....dialects._ods_common import (
14+
_cext,
15+
get_default_loc_context,
16+
get_op_result_or_op_results,
17+
)
1518
from ....dialects.gpu import *
1619
from ....ir import (
17-
Type,
18-
Attribute,
20+
ArrayAttr,
1921
AttrBuilder,
20-
UnitAttr,
21-
register_attribute_builder,
22+
Attribute,
2223
Context,
23-
ArrayAttr,
2424
InsertionPoint,
25+
Type,
26+
UnitAttr,
2527
Value,
28+
register_attribute_builder,
2629
)
2730

2831

mlir/extras/dialects/ext/linalg.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
from ...util import get_user_code_loc
2+
3+
# noinspection PyUnresolvedReferences
4+
from ....dialects.linalg import *
25
from ....dialects import linalg
36

47

0 commit comments

Comments
 (0)