|
1 | 1 | from . import arith
|
2 | 2 | from ...util import get_user_code_loc
|
3 | 3 | from ....dialects import linalg
|
| 4 | + |
4 | 5 | # noinspection PyUnresolvedReferences
|
5 | 6 | from ....dialects.linalg import *
|
6 | 7 | from ....extras import types as T
|
| 8 | +from .... import ir |
7 | 9 |
|
8 | 10 |
|
9 | 11 | def abs(I, O, *, loc=None, ip=None):
|
@@ -297,10 +299,57 @@ def log(I, O, *, loc=None, ip=None):
|
297 | 299 | return linalg.log(I, loc=loc, ip=ip, outs=[O])
|
298 | 300 |
|
299 | 301 |
|
# OpDSL description of a plain matrix multiplication: A is indexed (M, K), B is
# indexed (K, N) and the output C is indexed (M, N). Both inputs go through
# `cast` (signed cast by default) to C's element type U before being multiplied
# and accumulated into C over the (m, n, k) iteration domain.
@linalg.linalg_structured_op
def _matmul_generic(
    A=linalg.TensorDef(linalg.T1, linalg.S.M, linalg.S.K),
    B=linalg.TensorDef(linalg.T2, linalg.S.K, linalg.S.N),
    C=linalg.TensorDef(linalg.U, linalg.S.M, linalg.S.N, output=True),
    cast=linalg.TypeFnAttrDef(default=linalg.TypeFn.cast_signed),
):
    # Short aliases for the three iteration dimensions.
    m, n, k = linalg.D.m, linalg.D.n, linalg.D.k

    linalg.domain(m, n, k)
    linalg.implements(linalg.ContractionOpInterface)

    lhs = cast(linalg.U, A[m, k])
    rhs = cast(linalg.U, B[k, n])
    # NOTE(review): keep the augmented assignment; presumably OpDSL records the
    # reduction over k through `+=` -- confirm before rewriting as `C = C + ...`.
    C[m, n] += lhs * rhs


# Set the `op_name` attribute on the object returned by the decorator.
_matmul_generic.op_name = "matmul"
| 317 | + |
| 318 | + |
def matmul(A, B, C, *, loc=None, ip=None):
    """Build a ``linalg.MatmulOp`` with inputs ``A`` and ``B`` and output ``C``.

    When ``loc`` is ``None`` a location is obtained from
    ``get_user_code_loc()``. The result types and indexing maps passed to the
    op come from ``prepare_common_structured_op`` applied to the first op
    config derived from ``_matmul_generic``; the op's region is then populated
    with ``linalg.fill_builtin_region``.

    Returns the ``results`` of the created op.
    """
    loc = get_user_code_loc() if loc is None else loc

    structured_op = linalg.LinalgOpConfig.from_linalg_op_def(
        _matmul_generic.op_def, context=ir.Context.current
    )[0].structured_op

    # The helper's return value is unpacked into eleven slots; only the result
    # types and the indexing maps are needed here.
    (
        _,
        _,
        _,
        _,
        result_types,
        _,
        indexing_maps,
        _,
        _,
        _,
        _,
    ) = linalg.opdsl.lang.emitter.prepare_common_structured_op(
        structured_op, A, B, outs=[C], loc=loc, ip=ip
    )

    op = linalg.MatmulOp(
        result_types,
        inputs=[A, B],
        outputs=[C],
        indexing_maps=indexing_maps,
        cast=linalg.TypeFn.cast_signed,
        loc=loc,
        ip=ip,
    )
    linalg.fill_builtin_region(op.operation)
    return op.results
304 | 353 |
|
305 | 354 |
|
306 | 355 | def matmul_transpose_a(A, B, C, *, loc=None, ip=None):
|
|
0 commit comments