Skip to content

Commit ae86855

Browse files
authored
fix matmul (#105)
1 parent 014a5c1 commit ae86855

File tree

2 files changed

+52
-2
lines changed

2 files changed

+52
-2
lines changed

mlir/extras/dialects/ext/func.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
InsertionPoint,
1818
OpView,
1919
Operation,
20+
OpResultList,
2021
Type,
2122
TypeAttr,
2223
Value,
@@ -255,7 +256,7 @@ def emit(self, *call_args, decl=False, force=False) -> FuncOp:
255256
def grab_results(*args):
256257
nonlocal return_types
257258
results = self.body_builder(*args)
258-
if isinstance(results, (tuple, list)):
259+
if isinstance(results, (tuple, list, OpResultList)):
259260
return_types.extend([r.type for r in results])
260261
elif results is not None:
261262
return_types.append(results.type)

mlir/extras/dialects/ext/linalg.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
from . import arith
22
from ...util import get_user_code_loc
33
from ....dialects import linalg
4+
45
# noinspection PyUnresolvedReferences
56
from ....dialects.linalg import *
67
from ....extras import types as T
8+
from .... import ir
79

810

911
def abs(I, O, *, loc=None, ip=None):
@@ -297,10 +299,57 @@ def log(I, O, *, loc=None, ip=None):
297299
return linalg.log(I, loc=loc, ip=ip, outs=[O])
298300

299301

302+
@linalg.linalg_structured_op
303+
def _matmul_generic(
304+
A=linalg.TensorDef(linalg.T1, linalg.S.M, linalg.S.K),
305+
B=linalg.TensorDef(linalg.T2, linalg.S.K, linalg.S.N),
306+
C=linalg.TensorDef(linalg.U, linalg.S.M, linalg.S.N, output=True),
307+
cast=linalg.TypeFnAttrDef(default=linalg.TypeFn.cast_signed),
308+
):
309+
linalg.domain(linalg.D.m, linalg.D.n, linalg.D.k)
310+
linalg.implements(linalg.ContractionOpInterface)
311+
C[linalg.D.m, linalg.D.n] += cast(linalg.U, A[linalg.D.m, linalg.D.k]) * cast(
312+
linalg.U, B[linalg.D.k, linalg.D.n]
313+
)
314+
315+
316+
_matmul_generic.op_name = "matmul"
317+
318+
300319
def matmul(A, B, C, *, loc=None, ip=None):
301320
if loc is None:
302321
loc = get_user_code_loc()
303-
return linalg.matmul(A, B, loc=loc, ip=ip, outs=[C])
322+
323+
op_configs = linalg.LinalgOpConfig.from_linalg_op_def(
324+
_matmul_generic.op_def, context=ir.Context.current
325+
)
326+
op_config = op_configs[0]
327+
(
328+
_all_arg_defs,
329+
_in_arg_defs,
330+
_out_arg_defs,
331+
_outs,
332+
result_types,
333+
_type_mapping,
334+
indexing_maps_attr,
335+
_iterator_types_attr,
336+
_index_attrs,
337+
_fn_attr_mapping,
338+
_block_arg_types,
339+
) = linalg.opdsl.lang.emitter.prepare_common_structured_op(
340+
op_config.structured_op, A, B, outs=[C], loc=loc, ip=ip
341+
)
342+
named_op = linalg.MatmulOp(
343+
result_types,
344+
inputs=[A, B],
345+
outputs=[C],
346+
indexing_maps=indexing_maps_attr,
347+
cast=linalg.TypeFn.cast_signed,
348+
loc=loc,
349+
ip=ip,
350+
)
351+
linalg.fill_builtin_region(named_op.operation)
352+
return named_op.results
304353

305354

306355
def matmul_transpose_a(A, B, C, *, loc=None, ip=None):

0 commit comments

Comments
 (0)