Skip to content

Commit 034302b

Browse files
authored
fix linalg fill (#66)
1 parent a14994f commit 034302b

File tree

3 files changed

+61
-7
lines changed

3 files changed

+61
-7
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ Practically speaking that means you need to have *some* package installed that i
133133
So
134134

135135
```shell
136-
$ YOUR_HOST_MLIR_PYTHON_PACKAGE_PREFIX=<YOUR_HOST_MLIR_PYTHON_PACKAGE_PREFIX> pip install git+https://github.com/makslevental/mlir-python-extras
136+
$ HOST_MLIR_PYTHON_PACKAGE_PREFIX=<YOUR_HOST_MLIR_PYTHON_PACKAGE_PREFIX> pip install git+https://github.com/makslevental/mlir-python-extras
137137
```
138138

139139
where `YOUR_HOST_MLIR_PYTHON_PACKAGE_PREFIX` is (as it says) the package prefix for your chosen host bindings.

mlir/extras/dialects/ext/linalg.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1+
from . import arith
12
from ...util import get_user_code_loc
2-
3+
from ....dialects import linalg
34
# noinspection PyUnresolvedReferences
45
from ....dialects.linalg import *
5-
from ....dialects import linalg
6+
from ....extras import types as T
67

78

89
def abs(I, O, *, loc=None, ip=None):
@@ -263,16 +264,25 @@ def exp(I, O, *, loc=None, ip=None):
263264
return linalg.exp(I, loc=loc, ip=ip, outs=[O])
264265

265266

266-
def fill(O, *, loc=None, ip=None):
267+
def fill(v, O, *, loc=None, ip=None):
268+
if isinstance(v, (float, int, bool)):
269+
v = arith.constant(v)
267270
if loc is None:
268271
loc = get_user_code_loc()
269-
return linalg.fill(loc=loc, ip=ip, outs=[O])
272+
return linalg.fill(v, loc=loc, ip=ip, outs=[O])
270273

271274

272-
def fill_rng_2d(O, *, loc=None, ip=None):
275+
def fill_rng_2d(min, max, seed, O, *, loc=None, ip=None):
276+
params = [min, max]
277+
for i, m in enumerate(params):
278+
if isinstance(m, (float, int)):
279+
params[i] = arith.constant(m, type=T.f64())
280+
min, max = params
281+
if isinstance(seed, int):
282+
seed = arith.constant(seed, T.i32())
273283
if loc is None:
274284
loc = get_user_code_loc()
275-
return linalg.fill_rng_2d(loc=loc, ip=ip, outs=[O])
285+
return linalg.fill_rng_2d(min, max, seed, loc=loc, ip=ip, outs=[O])
276286

277287

278288
def floor(I, O, *, loc=None, ip=None):

tests/test_linalg.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
from textwrap import dedent
2+
3+
import pytest
4+
5+
import mlir.extras.types as T
6+
from mlir.extras.dialects.ext import linalg, memref, tensor
7+
8+
# noinspection PyUnresolvedReferences
9+
from mlir.extras.testing import MLIRContext, filecheck, mlir_ctx as ctx
10+
11+
# needed since the fix isn't defined here nor conftest.py
12+
pytest.mark.usefixtures("ctx")
13+
14+
15+
def test_np_constructor(ctx: MLIRContext):
16+
x = memref.alloc(10, 10, T.i32())
17+
linalg.fill(5, x)
18+
linalg.fill_rng_2d(0.0, 10.0, 1, x)
19+
20+
x = tensor.empty(10, 10, T.i32())
21+
y = linalg.fill_rng_2d(0.0, 10.0, 1, x)
22+
z = linalg.fill(5, x)
23+
24+
correct = dedent(
25+
"""\
26+
module {
27+
%alloc = memref.alloc() : memref<10x10xi32>
28+
%c5_i32 = arith.constant 5 : i32
29+
linalg.fill ins(%c5_i32 : i32) outs(%alloc : memref<10x10xi32>)
30+
%cst = arith.constant 0.000000e+00 : f64
31+
%cst_0 = arith.constant 1.000000e+01 : f64
32+
%c1_i32 = arith.constant 1 : i32
33+
linalg.fill_rng_2d ins(%cst, %cst_0, %c1_i32 : f64, f64, i32) outs(%alloc : memref<10x10xi32>)
34+
%0 = tensor.empty() : tensor<10x10xi32>
35+
%cst_1 = arith.constant 0.000000e+00 : f64
36+
%cst_2 = arith.constant 1.000000e+01 : f64
37+
%c1_i32_3 = arith.constant 1 : i32
38+
%1 = linalg.fill_rng_2d ins(%cst_1, %cst_2, %c1_i32_3 : f64, f64, i32) outs(%0 : tensor<10x10xi32>) -> tensor<10x10xi32>
39+
%c5_i32_4 = arith.constant 5 : i32
40+
%2 = linalg.fill ins(%c5_i32_4 : i32) outs(%0 : tensor<10x10xi32>) -> tensor<10x10xi32>
41+
}
42+
"""
43+
)
44+
filecheck(correct, ctx.module)

0 commit comments

Comments
 (0)