45 changes: 45 additions & 0 deletions mlir/python/mlir/dialects/linalg/__init__.py
@@ -10,6 +10,7 @@
# DSL -> YAML -> tblgen -> pytblgen -> build/.../_linalg_ops_gen.py.
from .._linalg_ops_gen import *
from .._linalg_enum_gen import *
from .._linalg_enum_gen import _iteratortypeenum

# These are the ground truth functions defined as:
# ```
@@ -58,6 +59,7 @@

from ...ir import *
from .._ods_common import get_op_result_or_value as _get_op_result_or_value
from ...extras.meta import region_op


def transpose(
@@ -102,3 +104,46 @@ def broadcast(
    )
    fill_builtin_region(op.operation)
    return op


@register_attribute_builder("IteratorTypeArrayAttr")
def _IteratorTypeArrayAttr(x, context):
    return ArrayAttr.get([_iteratortypeenum(v, context) for v in x])
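Note: ODS-generated op builders look up attribute builders registered under the attribute's ODS name, so this registration lets callers pass a plain Python list of `linalg.IteratorType` values wherever an `IteratorTypeArrayAttr` is expected. A minimal sketch of the conversion, assuming a live context (the printed form is an expectation, not verified output):

```python
from mlir.ir import Context
from mlir.dialects import linalg

with Context() as ctx:
    # Each enum value goes through the generated _iteratortypeenum
    # converter; the results are wrapped in an ArrayAttr.
    attr = linalg._IteratorTypeArrayAttr(
        [linalg.IteratorType.parallel, linalg.IteratorType.reduction], ctx
    )
    # Expected to print roughly:
    # [#linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>]
    print(attr)
```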


# The trailing underscore avoids a name collision with the opdsl-generated
# GenericOp.
class GenericOp_(GenericOp):
    def __init__(
        self,
        inputs,
        outputs,
        indexing_maps,
        iterator_types,
        *,
        doc=None,
        library_call=None,
        loc=None,
        ip=None,
    ):
        # Tensor semantics: with ranked-tensor outputs the op produces one
        # result per output; with memref outputs it produces none.
        result_types = []
        if isinstance(outputs[0].type, RankedTensorType):
            result_types = [o.type for o in outputs]

        super().__init__(
            result_types,
            inputs,
            outputs,
            indexing_maps,
            iterator_types,
            doc=doc,
            library_call=library_call,
            loc=loc,
            ip=ip,
        )
        element_types = [i.type.element_type for i in inputs] + [
            o.type.element_type for o in outputs
        ]
        # Create the entry block with one scalar argument per operand.
        self.regions[0].blocks.append(*element_types)


generic = region_op(GenericOp_, terminator=YieldOp)
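`region_op` (from `mlir.extras.meta`) is what turns `GenericOp_` into the decorator used in the tests below: it constructs the op, runs the decorated function inside the op's entry block with the block arguments bound to the function's parameters, and hands the returned values to the terminator. A rough sketch of that pattern, under one reading of the helper rather than its exact code:

```python
from mlir.ir import InsertionPoint


def region_op_sketch(op_ctor, terminator=None):
    # region_op_sketch(GenericOp_, terminator=YieldOp) returns something
    # usable as @generic(inputs, outputs, indexing_maps, iterator_types).
    def decorator_factory(*args, **kwargs):
        op = op_ctor(*args, **kwargs)

        def decorator(body_fn):
            block = op.regions[0].blocks[0]
            with InsertionPoint(block):
                # One scalar block argument per input/output element type,
                # matching the block created in GenericOp_.__init__.
                results = body_fn(*block.arguments)
                if terminator is not None:
                    if results is None:
                        results = []
                    elif not isinstance(results, (tuple, list)):
                        results = [results]
                    terminator(results)
            # Tensor outputs give the op results (a Value for one result,
            # an OpResultList for several); the memref form has none.
            if len(op.results) == 0:
                return op
            if len(op.results) == 1:
                return op.results[0]
            return op.results

        return decorator

    return decorator_factory
```

This is why `f` in the test below is bound to a `Value`, `g` to an `OpResultList` of length 2, and the memref variant to no SSA results.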
97 changes: 96 additions & 1 deletion mlir/test/python/dialects/linalg/ops.py
@@ -1,6 +1,6 @@
# RUN: %PYTHON %s | FileCheck %s

-from mlir.dialects import arith, builtin, func, linalg, tensor
+from mlir.dialects import arith, func, linalg, tensor, memref
from mlir.dialects.linalg.opdsl.lang import *
from mlir.ir import *

@@ -84,6 +84,7 @@ def named_form(lhs, rhs):

        print(module)


# CHECK-LABEL: TEST: testIdentityRegionOps
@run
def testIdentityRegionOps():
Expand Down Expand Up @@ -161,3 +162,97 @@ def broadcast_op(op1, op2, op3):
op5 = linalg.broadcast(op3, outs=[op2], dimensions=[0])

print(module)


# CHECK-LABEL: TEST: testGenericOp
@run
def testGenericOp():
    with Context(), Location.unknown():
        module = Module.create()
        f32 = F32Type.get()
        memref_t = MemRefType.get([10, 10], f32)
        with InsertionPoint(module.body):
            id_map_1 = AffineMap.get_identity(2)
            # CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<16x16xf32>
            # CHECK: %[[VAL_1:.*]] = tensor.empty() : tensor<16x16xf32>
            x = tensor.empty((16, 16), f32)
            y = tensor.empty((16, 16), f32)

            # CHECK: %[[VAL_2:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_0]] : tensor<16x16xf32>) outs(%[[VAL_1]] : tensor<16x16xf32>) {
            # CHECK: ^bb0(%in: f32, %out: f32):
            # CHECK: linalg.yield %in : f32
            # CHECK: } -> tensor<16x16xf32>
            @linalg.generic(
                [x],
                [y],
                [id_map_1, id_map_1],
                [linalg.IteratorType.parallel, linalg.IteratorType.parallel],
            )
            def f(a, b):
                assert isinstance(a, Value)
                assert isinstance(a.type, F32Type)
                assert isinstance(b, Value)
                assert isinstance(b.type, F32Type)
                return a

            assert isinstance(f, Value)
            assert isinstance(f.type, RankedTensorType)

            # CHECK: %[[VAL_3:.*]] = tensor.empty() : tensor<16x16x16xf32>
            z = tensor.empty((16, 16, 16), f32)

            minor_id = AffineMap.get_minor_identity(3, 2)
            id_map_2 = AffineMap.get_identity(3)

            # CHECK: %[[VAL_4:.+]]:2 = linalg.generic {indexing_maps = [#map1, #map2, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[VAL_0]] : tensor<16x16xf32>) outs(%[[VAL_3]], %[[VAL_3]] : tensor<16x16x16xf32>, tensor<16x16x16xf32>) {
            # CHECK: ^bb0(%in: f32, %out: f32, %out_1: f32):
            # CHECK: linalg.yield %in, %out : f32, f32
            # CHECK: } -> (tensor<16x16x16xf32>, tensor<16x16x16xf32>)
            @linalg.generic(
                [x],
                [z, z],
                [minor_id, id_map_2, id_map_2],
                [
                    linalg.IteratorType.parallel,
                    linalg.IteratorType.parallel,
                    linalg.IteratorType.parallel,
                ],
            )
            def g(a, b, c):
                assert isinstance(a, Value)
                assert isinstance(a.type, F32Type)
                assert isinstance(b, Value)
                assert isinstance(b.type, F32Type)
                assert isinstance(c, Value)
                assert isinstance(c.type, F32Type)
                return a, b

            assert isinstance(g, OpResultList)
            assert len(g) == 2
            assert isinstance(g[0].type, RankedTensorType)
            assert isinstance(g[1].type, RankedTensorType)

            # CHECK: %[[VAL_5:.*]] = memref.alloc() : memref<10x10xf32>
            # CHECK: %[[VAL_6:.*]] = memref.alloc() : memref<10x10xf32>
            xx = memref.alloc(memref_t, [], [])
            yy = memref.alloc(memref_t, [], [])

            # CHECK: linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_5]] : memref<10x10xf32>) outs(%[[VAL_6]] : memref<10x10xf32>) {
            # CHECK: ^bb0(%in: f32, %out: f32):
            # CHECK: linalg.yield %in : f32
            # CHECK: }
            @linalg.generic(
                [xx],
                [yy],
                [id_map_1, id_map_1],
                [linalg.IteratorType.parallel, linalg.IteratorType.parallel],
            )
            def f(a, b):
                assert isinstance(a, Value)
                assert isinstance(a.type, F32Type)
                assert isinstance(b, Value)
                assert isinstance(b.type, F32Type)
                return a

        module.operation.verify()
        print(module)
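For readers matching the CHECK lines against the printed module: the `#map` aliases are assigned by the printer in order of first use. Under that reading, `#map` is `id_map_1`, `#map1` is `minor_id`, and `#map2` is `id_map_2`; a small hedged sanity check of the underlying maps:

```python
from mlir.ir import AffineMap, Context

with Context():
    # id_map_1 / #map: the rank-2 identity.
    assert str(AffineMap.get_identity(2)) == "(d0, d1) -> (d0, d1)"
    # minor_id / #map1: keeps the minor (trailing) 2 of 3 dims.
    assert str(AffineMap.get_minor_identity(3, 2)) == "(d0, d1, d2) -> (d1, d2)"
    # id_map_2 / #map2: the rank-3 identity.
    assert str(AffineMap.get_identity(3)) == "(d0, d1, d2) -> (d0, d1, d2)"
```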