Skip to content

Commit ade0e39

Browse files
authored
add vec ops (#67)
1 parent 034302b commit ade0e39

File tree

3 files changed

+215
-9
lines changed

3 files changed

+215
-9
lines changed

mlir/extras/dialects/ext/_shaped_value.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@ def has_static_shape(self) -> bool:
2626
def has_rank(self) -> bool:
2727
return self._shaped_type.has_rank
2828

29+
@cached_property
def rank(self) -> int:
    """Return the ``rank`` attribute of the wrapped shaped type."""
    shaped_type = self._shaped_type
    return shaped_type.rank
32+
2933
@cached_property
def shape(self) -> Tuple[int, ...]:
    """Return the wrapped shaped type's ``shape`` converted to a tuple."""
    dims = self._shaped_type.shape
    return tuple(dims)

mlir/extras/dialects/ext/vector.py

Lines changed: 157 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,167 @@
1+
from typing import List
2+
13
from ._shaped_value import ShapedValue
2-
from .arith import ArithValue
4+
from .arith import ArithValue, FastMathFlags, constant
5+
from ...util import get_user_code_loc
36
from ...._mlir_libs._mlir import register_value_caster
7+
from ....dialects._ods_common import _dispatch_mixed_values
48

59
# noinspection PyUnresolvedReferences
610
from ....dialects.vector import *
7-
from ....ir import VectorType
11+
from ....extras import types as T
12+
from ....ir import AffineMap, VectorType
813

914

1015
@register_value_caster(VectorType.static_typeid)
class Vector(ShapedValue, ArithValue):
    """Value class registered as the caster for ``VectorType`` values.

    It adds nothing of its own; all behavior is inherited from
    ``ShapedValue`` and ``ArithValue``.
    """
18+
19+
20+
# Keep a handle on the star-imported ``transfer_write`` before the wrapper
# below rebinds the name.
_transfer_write = transfer_write


def transfer_write(
    vector: Vector,
    source,
    indices,
    *,
    permutation_map=None,
    mask: List[int] = None,
    in_bounds: List[bool] = None,
    loc=None,
    ip=None
):
    """Call the underlying ``transfer_write`` builder with convenience defaults.

    Args:
        vector: value to write; ``vector.type.rank`` feeds the default
            permutation map.
        source: destination; ``source.type.rank`` feeds the default
            permutation map.
        indices: iterable of start indices. Python ``int`` entries are turned
            into index constants via ``constant(i, index=True)``; other
            entries are forwarded as given.
        permutation_map: when ``None``, ``AffineMap.get_minor_identity`` of
            the source rank and the vector rank is used.
        mask: forwarded unchanged.
        in_bounds: forwarded unchanged.
        loc: when ``None``, ``get_user_code_loc()`` is used.
        ip: forwarded unchanged.

    Returns:
        Whatever the underlying builder returns.
    """
    if loc is None:
        loc = get_user_code_loc()
    if permutation_map is None:
        permutation_map = AffineMap.get_minor_identity(
            source.type.rank, vector.type.rank
        )
    # Build a fresh list instead of assigning into ``indices``: the caller's
    # sequence is left untouched and immutable sequences (tuples) work too.
    indices = [
        constant(i, index=True) if isinstance(i, int) else i for i in indices
    ]
    return _transfer_write(
        result=None,
        vector=vector,
        source=source,
        indices=indices,
        permutation_map=permutation_map,
        mask=mask,
        in_bounds=in_bounds,
        loc=loc,
        ip=ip,
    )
54+
55+
56+
# Keep a handle on the star-imported ``transfer_read`` before the wrapper
# below rebinds the name.
_transfer_read = transfer_read


def transfer_read(
    vector_t,
    source,
    indices,
    *,
    permutation_map=None,
    padding=None,
    mask=None,
    in_bounds=None,
    loc=None,
    ip=None
):
    """Call the underlying ``transfer_read`` builder with convenience defaults.

    Args:
        vector_t: result vector type; ``vector_t.rank`` feeds the default
            permutation map.
        source: value read from; ``source.type.rank`` and
            ``source.type.element_type`` are used for the defaults.
        indices: iterable of start indices. Python ``int`` entries are turned
            into index constants via ``constant(i, index=True)``; other
            entries are forwarded as given.
        permutation_map: when ``None``, ``AffineMap.get_minor_identity`` of
            the source rank and the vector rank is used.
        padding: when ``None``, ``0`` is used. A Python ``int`` or ``float``
            is turned into a constant of the source element type; anything
            else is forwarded as given.
        mask: forwarded unchanged.
        in_bounds: forwarded unchanged.
        loc: when ``None``, ``get_user_code_loc()`` is used.
        ip: forwarded unchanged.

    Returns:
        Whatever the underlying builder returns.
    """
    if loc is None:
        loc = get_user_code_loc()
    if permutation_map is None:
        permutation_map = AffineMap.get_minor_identity(source.type.rank, vector_t.rank)
    # Build a fresh list instead of assigning into ``indices``: the caller's
    # sequence is left untouched and immutable sequences (tuples) work too.
    indices = [
        constant(i, index=True) if isinstance(i, int) else i for i in indices
    ]
    if padding is None:
        padding = 0
    # Python floats need the same treatment as ints; otherwise a raw float
    # would be handed to the builder where a value is expected.
    if isinstance(padding, (int, float)):
        padding = constant(padding, type=source.type.element_type)

    return _transfer_read(
        vector=vector_t,
        source=source,
        indices=indices,
        permutation_map=permutation_map,
        padding=padding,
        mask=mask,
        in_bounds=in_bounds,
        loc=loc,
        ip=ip,
    )
94+
95+
96+
# Keep a handle on the star-imported ``extract`` before the wrapper below
# rebinds the name.
_extract = extract


def extract(vector, position, *, loc=None, ip=None):
    """Call the underlying ``extract`` builder with a mixed ``position``.

    ``position`` is split by ``_dispatch_mixed_values`` into dynamic and
    static parts, which are forwarded as ``dynamic_position`` and
    ``static_position``. ``loc`` falls back to ``get_user_code_loc()`` when
    it is ``None``.
    """
    loc = get_user_code_loc() if loc is None else loc
    # The middle (packed) element of the dispatch result is not needed here.
    dynamic_part, _, static_part = _dispatch_mixed_values(position)
    return _extract(
        vector=vector,
        dynamic_position=dynamic_part,
        static_position=static_part,
        loc=loc,
        ip=ip,
    )
112+
113+
114+
# Keep a handle on the star-imported ``reduction`` before the wrapper below
# rebinds the name.
_reduction = reduction


def reduction(
    kind: CombiningKind,
    vector,
    *,
    acc=None,
    fastmath: FastMathFlags = None,
    loc=None,
    ip=None
):
    """Call the underlying ``reduction`` builder, deriving ``dest`` from the input.

    The result type passed as ``dest`` is ``vector.type.element_type``.
    ``kind``, ``acc``, ``fastmath`` and ``ip`` are forwarded unchanged;
    ``loc`` falls back to ``get_user_code_loc()`` when it is ``None``.
    """
    loc = get_user_code_loc() if loc is None else loc
    return _reduction(
        dest=vector.type.element_type,
        kind=kind,
        vector=vector,
        acc=acc,
        fastmath=fastmath,
        loc=loc,
        ip=ip,
    )
138+
139+
140+
# Keep a handle on the star-imported ``broadcast`` before the wrapper below
# rebinds the name.
_broadcast = broadcast


def broadcast(vector, source, *, loc=None, ip=None):
    """Call the underlying ``broadcast`` builder, accepting Python scalars.

    Args:
        vector: result vector type, forwarded as ``vector``.
            NOTE(review): assumed to expose ``element_type``, as the vector
            types built with ``T.vector(...)`` do; confirm against callers.
        source: value to broadcast. A Python ``float``, ``int`` or ``bool``
            is first turned into a constant of the result element type.
        loc: when ``None``, ``get_user_code_loc()`` is used.
        ip: forwarded unchanged.

    Returns:
        Whatever the underlying builder returns.
    """
    if loc is None:
        loc = get_user_code_loc()
    if isinstance(source, (float, int, bool)):
        # Type the scalar after the target vector rather than letting
        # ``constant`` infer a type from the Python value; an inferred type
        # need not match the vector's element type.
        source = constant(source, type=vector.element_type)
    return _broadcast(vector=vector, source=source, loc=loc, ip=ip)
149+
150+
151+
# Keep a handle on the star-imported ``extract_strided_slice`` before the
# wrapper below rebinds the name.
_extract_strided_slice = extract_strided_slice


def extract_strided_slice(vector, offsets, sizes, strides, *, loc=None, ip=None):
    """Call the underlying ``extract_strided_slice`` builder with a computed result type.

    The result shape is ``sizes`` (each converted with ``int``) followed by
    the dimensions of ``vector.type.shape`` beyond ``len(sizes)``; the element
    type is that of ``vector``. ``loc`` falls back to ``get_user_code_loc()``
    when it is ``None``.
    """
    loc = get_user_code_loc() if loc is None else loc
    vector_type = vector.type
    sliced_dims = [int(s) for s in sizes]
    untouched_dims = vector_type.shape[len(sizes) :]
    result_shape = sliced_dims + untouched_dims
    result = T.vector(*result_shape, vector_type.element_type)
    return _extract_strided_slice(
        result=result,
        vector=vector,
        offsets=offsets,
        sizes=sizes,
        strides=strides,
        loc=loc,
        ip=ip,
    )

tests/test_vector.py

Lines changed: 54 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
1-
import sys
1+
from textwrap import dedent
22

33
import numpy as np
44
import pytest
55

6-
# you need this to register the memref value caster
7-
# noinspection PyUnresolvedReferences
8-
import mlir.extras.dialects.ext.memref
96
from mlir.dialects import builtin
7+
from mlir.dialects.bufferization import LayoutMapOption
108
from mlir.dialects.transform import (
119
any_op_t,
1210
)
@@ -19,8 +17,10 @@
1917
)
2018
from mlir.extras import types as T
2119
from mlir.extras.context import ExplicitlyManagedModule, RAIIMLIRContext
22-
from mlir.extras.dialects.ext import linalg, transform, arith, vector
23-
from mlir.dialects.bufferization import LayoutMapOption
20+
21+
# you need this to register the memref value caster
22+
# noinspection PyUnresolvedReferences
23+
from mlir.extras.dialects.ext import arith, linalg, memref, transform, vector
2424
from mlir.extras.dialects.ext.func import func
2525
from mlir.extras.dialects.ext.transform import (
2626
get_parent_op,
@@ -154,7 +154,6 @@ def pats():
154154
kernel_name=smol_matmul.__name__,
155155
pipeline=lower_to_llvm,
156156
)
157-
print(compiled_module)
158157

159158
A = np.random.randint(0, 10, (M, K)).astype(np.float32)
160159
B = np.random.randint(0, 10, (K, N)).astype(np.float32)
@@ -172,3 +171,51 @@ def test_np_constructor(ctx: MLIRContext):
172171
repr(vec)
173172
== f"Vector(%cst = arith.constant dense<{vec.literal_value.tolist()}> : vector<2x4xi32>)"
174173
)
174+
175+
176+
def test_vector_wrappers(ctx: MLIRContext):
    """Build ops through the ``vector`` wrappers and check the printed module.

    Covers ``transfer_read``, ``extract``, ``transfer_write``, ``broadcast``,
    ``reduction``, ``multi_reduction`` and ``extract_strided_slice``; the
    resulting ``ctx.module`` is compared to the expected IR with ``filecheck``.
    """
    M, K, N = 2, 4, 6
    mem = memref.alloc(M, K, N, T.i32())
    # Python ints are used for the indices and for the padding value.
    vec = vector.transfer_read(T.vector(M, K, T.i32()), mem, [0, 0, 0], padding=5)
    e_vec = vector.extract(vec, [0])
    vector.transfer_write(e_vec, mem, [0, 0, 0], in_bounds=[True])

    # Broadcast a Python scalar, then reduce the resulting vector.
    b = vector.broadcast(T.vector(10, T.i32()), 5)
    r = vector.reduction(vector.CombiningKind.ADD, b)

    # Multi-dimensional reduction over dims 1 and 3 with an explicit accumulator.
    b = vector.broadcast(T.vector(4, 8, 16, 32, T.i32()), 5)
    acc = vector.broadcast(T.vector(4, 16, T.i32()), 0)
    r = vector.multi_reduction(vector.CombiningKind.ADD, b, acc, [1, 3])

    # Slice only the two leading dims; the expected result keeps the trailing 16.
    b = vector.broadcast(T.vector(4, 8, 16, T.i32()), 5)
    e = vector.extract_strided_slice(b, [0, 2], [2, 4], [1, 1])

    correct = dedent(
        """\
    module {
      %alloc = memref.alloc() : memref<2x4x6xi32>
      %c0 = arith.constant 0 : index
      %c0_0 = arith.constant 0 : index
      %c0_1 = arith.constant 0 : index
      %c5_i32 = arith.constant 5 : i32
      %0 = vector.transfer_read %alloc[%c0, %c0_0, %c0_1], %c5_i32 : memref<2x4x6xi32>, vector<2x4xi32>
      %1 = vector.extract %0[0] : vector<4xi32> from vector<2x4xi32>
      %c0_2 = arith.constant 0 : index
      %c0_3 = arith.constant 0 : index
      %c0_4 = arith.constant 0 : index
      vector.transfer_write %1, %alloc[%c0_2, %c0_3, %c0_4] {in_bounds = [true]} : vector<4xi32>, memref<2x4x6xi32>
      %c5_i32_5 = arith.constant 5 : i32
      %2 = vector.broadcast %c5_i32_5 : i32 to vector<10xi32>
      %3 = vector.reduction <add>, %2 : vector<10xi32> into i32
      %c5_i32_6 = arith.constant 5 : i32
      %4 = vector.broadcast %c5_i32_6 : i32 to vector<4x8x16x32xi32>
      %c0_i32 = arith.constant 0 : i32
      %5 = vector.broadcast %c0_i32 : i32 to vector<4x16xi32>
      %6 = vector.multi_reduction <add>, %4, %5 [1, 3] : vector<4x8x16x32xi32> to vector<4x16xi32>
      %c5_i32_7 = arith.constant 5 : i32
      %7 = vector.broadcast %c5_i32_7 : i32 to vector<4x8x16xi32>
      %8 = vector.extract_strided_slice %7 {offsets = [0, 2], sizes = [2, 4], strides = [1, 1]} : vector<4x8x16xi32> to vector<2x4x16xi32>
    }
    """
    )
    filecheck(correct, ctx.module)

0 commit comments

Comments
 (0)