5858from .opdsl .ops .core_named_ops import *
5959
6060from ...ir import *
61- from .._ods_common import get_op_result_or_value as _get_op_result_or_value
61+ from .._ods_common import (
62+ get_op_result_or_value as _get_op_result_or_value ,
63+ get_op_result_or_op_results as _get_op_result_or_op_results ,
64+ _dispatch_mixed_values ,
65+ )
6266from ...extras .meta import region_op
6367
6468
@@ -149,7 +153,7 @@ def __init__(
149153generic = region_op (GenericOp_ , terminator = YieldOp )
150154
151155
152- def create_op (
156+ def _create_matmul_like_op (
153157 op_type ,
154158 * ins : Union [Operation , OpView , Value ],
155159 outs : Sequence [Union [Operation , OpView , Value ]],
@@ -179,7 +183,11 @@ def matmul(
179183 indexing_maps : Optional [Sequence [AffineMapAttr ]] = None ,
180184 cast : Optional [Union [TypeFn , Attribute ]] = None ,
181185):
182- return create_op (MatmulOp , * ins , outs = outs , indexing_maps = indexing_maps , cast = cast )
186+ return _get_op_result_or_op_results (
187+ _create_matmul_like_op (
188+ MatmulOp , * ins , outs = outs , indexing_maps = indexing_maps , cast = cast
189+ )
190+ )
183191
184192
185193def batch_matmul (
@@ -188,8 +196,10 @@ def batch_matmul(
188196 indexing_maps : Optional [Sequence [AffineMapAttr ]] = None ,
189197 cast : Optional [Union [TypeFn , Attribute ]] = None ,
190198):
191- return create_op (
192- BatchMatmulOp , * ins , outs = outs , indexing_maps = indexing_maps , cast = cast
199+ return _get_op_result_or_op_results (
200+ _create_matmul_like_op (
201+ BatchMatmulOp , * ins , outs = outs , indexing_maps = indexing_maps , cast = cast
202+ )
193203 )
194204
195205
@@ -199,6 +209,72 @@ def contract(
199209 indexing_maps : Sequence [AffineMapAttr ],
200210 cast : Optional [Union [TypeFn , Attribute ]] = None ,
201211):
202- return create_op (
203- ContractOp , * ins , outs = outs , indexing_maps = indexing_maps , cast = cast
212+ return _get_op_result_or_op_results (
213+ _create_matmul_like_op (
214+ ContractOp , * ins , outs = outs , indexing_maps = indexing_maps , cast = cast
215+ )
216+ )
217+
218+
219+ def pack (
220+ source ,
221+ dest ,
222+ inner_dims_pos ,
223+ inner_tiles ,
224+ * ,
225+ padding_value = None ,
226+ outer_dims_perm = None ,
227+ loc = None ,
228+ ip = None ,
229+ ) -> ir .Value :
230+ (
231+ dynamic_inner_tiles ,
232+ # packed here means %1:2 packing (results packing)
233+ _inner_tiles ,
234+ static_inner_tiles ,
235+ ) = _dispatch_mixed_values (inner_tiles )
236+
237+ return _get_op_result_or_op_results (
238+ PackOp (
239+ source = source ,
240+ dest = dest ,
241+ inner_dims_pos = inner_dims_pos ,
242+ inner_tiles = dynamic_inner_tiles ,
243+ static_inner_tiles = static_inner_tiles ,
244+ padding_value = padding_value ,
245+ outer_dims_perm = outer_dims_perm ,
246+ loc = loc ,
247+ ip = ip ,
248+ )
249+ )
250+
251+
252+ def unpack (
253+ source ,
254+ dest ,
255+ inner_dims_pos ,
256+ inner_tiles ,
257+ * ,
258+ outer_dims_perm = None ,
259+ loc = None ,
260+ ip = None ,
261+ ) -> ir .Value :
262+ (
263+ dynamic_inner_tiles ,
264+ # packed here means %1:2 packing (results packing)
265+ _inner_tiles ,
266+ static_inner_tiles ,
267+ ) = _dispatch_mixed_values (inner_tiles )
268+
269+ return _get_op_result_or_op_results (
270+ UnPackOp (
271+ source = source ,
272+ dest = dest ,
273+ inner_dims_pos = inner_dims_pos ,
274+ inner_tiles = dynamic_inner_tiles ,
275+ static_inner_tiles = static_inner_tiles ,
276+ outer_dims_perm = outer_dims_perm ,
277+ loc = loc ,
278+ ip = ip ,
279+ )
204280 )
0 commit comments