Skip to content
This repository was archived by the owner on Oct 11, 2025. It is now read-only.

Commit c6275ab

Browse files
authored
[mlir][python] fix linalg.pack/unpack (#127729)
This PR llvm/llvm-project#123902 broke python bindings for `tensor.pack`/`unpack`. This PR fixes that. It also 1. adds convenience wrappers for pack/unpack 2. cleans up matmul-like ops in the linalg bindings 3. fixes linalg docs missing pack/unpack
1 parent 94435d8 commit c6275ab

File tree

2 files changed

+84
-7
lines changed

2 files changed

+84
-7
lines changed

mlir/python/mlir/dialects/LinalgOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,6 @@
1111

1212
include "mlir/Dialect/Linalg/IR/LinalgOps.td"
1313
include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.td"
14+
include "mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td"
1415

1516
#endif

mlir/python/mlir/dialects/linalg/__init__.py

Lines changed: 83 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,11 @@
5858
from .opdsl.ops.core_named_ops import *
5959

6060
from ...ir import *
61-
from .._ods_common import get_op_result_or_value as _get_op_result_or_value
61+
from .._ods_common import (
62+
get_op_result_or_value as _get_op_result_or_value,
63+
get_op_result_or_op_results as _get_op_result_or_op_results,
64+
_dispatch_mixed_values,
65+
)
6266
from ...extras.meta import region_op
6367

6468

@@ -149,7 +153,7 @@ def __init__(
149153
generic = region_op(GenericOp_, terminator=YieldOp)
150154

151155

152-
def create_op(
156+
def _create_matmul_like_op(
153157
op_type,
154158
*ins: Union[Operation, OpView, Value],
155159
outs: Sequence[Union[Operation, OpView, Value]],
@@ -179,7 +183,11 @@ def matmul(
179183
indexing_maps: Optional[Sequence[AffineMapAttr]] = None,
180184
cast: Optional[Union[TypeFn, Attribute]] = None,
181185
):
182-
return create_op(MatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast)
186+
return _get_op_result_or_op_results(
187+
_create_matmul_like_op(
188+
MatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
189+
)
190+
)
183191

184192

185193
def batch_matmul(
@@ -188,8 +196,10 @@ def batch_matmul(
188196
indexing_maps: Optional[Sequence[AffineMapAttr]] = None,
189197
cast: Optional[Union[TypeFn, Attribute]] = None,
190198
):
191-
return create_op(
192-
BatchMatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
199+
return _get_op_result_or_op_results(
200+
_create_matmul_like_op(
201+
BatchMatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
202+
)
193203
)
194204

195205

@@ -199,6 +209,72 @@ def contract(
199209
indexing_maps: Sequence[AffineMapAttr],
200210
cast: Optional[Union[TypeFn, Attribute]] = None,
201211
):
202-
return create_op(
203-
ContractOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
212+
return _get_op_result_or_op_results(
213+
_create_matmul_like_op(
214+
ContractOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
215+
)
216+
)
217+
218+
219+
def pack(
220+
source,
221+
dest,
222+
inner_dims_pos,
223+
inner_tiles,
224+
*,
225+
padding_value=None,
226+
outer_dims_perm=None,
227+
loc=None,
228+
ip=None,
229+
) -> ir.Value:
230+
(
231+
dynamic_inner_tiles,
232+
# packed here means %1:2 packing (results packing)
233+
_inner_tiles,
234+
static_inner_tiles,
235+
) = _dispatch_mixed_values(inner_tiles)
236+
237+
return _get_op_result_or_op_results(
238+
PackOp(
239+
source=source,
240+
dest=dest,
241+
inner_dims_pos=inner_dims_pos,
242+
inner_tiles=dynamic_inner_tiles,
243+
static_inner_tiles=static_inner_tiles,
244+
padding_value=padding_value,
245+
outer_dims_perm=outer_dims_perm,
246+
loc=loc,
247+
ip=ip,
248+
)
249+
)
250+
251+
252+
def unpack(
253+
source,
254+
dest,
255+
inner_dims_pos,
256+
inner_tiles,
257+
*,
258+
outer_dims_perm=None,
259+
loc=None,
260+
ip=None,
261+
) -> ir.Value:
262+
(
263+
dynamic_inner_tiles,
264+
# packed here means %1:2 packing (results packing)
265+
_inner_tiles,
266+
static_inner_tiles,
267+
) = _dispatch_mixed_values(inner_tiles)
268+
269+
return _get_op_result_or_op_results(
270+
UnPackOp(
271+
source=source,
272+
dest=dest,
273+
inner_dims_pos=inner_dims_pos,
274+
inner_tiles=dynamic_inner_tiles,
275+
static_inner_tiles=static_inner_tiles,
276+
outer_dims_perm=outer_dims_perm,
277+
loc=loc,
278+
ip=ip,
279+
)
204280
)

0 commit comments

Comments
 (0)