@@ -149,7 +149,8 @@ def __init__(
149149generic = region_op (GenericOp_ , terminator = YieldOp )
150150
151151
152- def matmul (
152+ def create_op (
153+ op_type ,
153154 * ins : Union [Operation , OpView , Value ],
154155 outs : Sequence [Union [Operation , OpView , Value ]],
155156 indexing_maps : Optional [Sequence [AffineMapAttr ]] = None ,
@@ -161,7 +162,7 @@ def matmul(
161162 init = _get_op_result_or_value (outs [0 ])
162163 result_types = [init .type ] if isinstance (init .type , RankedTensorType ) else []
163164
164- op = MatmulOp (
165+ op = op_type (
165166 result_tensors = result_types ,
166167 inputs = ins ,
167168 outputs = [init ],
@@ -172,45 +173,32 @@ def matmul(
172173 return op
173174
174175
175- def contract (
176+ def matmul (
176177 * ins : Union [Operation , OpView , Value ],
177178 outs : Sequence [Union [Operation , OpView , Value ]],
178- indexing_maps : Sequence [AffineMapAttr ],
179+ indexing_maps : Optional [ Sequence [AffineMapAttr ]] = None ,
179180 cast : Optional [Union [TypeFn , Attribute ]] = None ,
180181):
181- ins = [_get_op_result_or_value (input ) for input in ins ]
182- if len (outs ) > 1 :
183- raise ValueError (f"{ outs = } must have length 1." )
184- init = _get_op_result_or_value (outs [0 ])
185- result_types = [init .type ] if isinstance (init .type , RankedTensorType ) else []
186-
187- op = ContractOp (
188- result_tensors = result_types ,
189- inputs = ins ,
190- outputs = [init ],
191- indexing_maps = indexing_maps ,
192- cast = cast ,
193- )
194- fill_builtin_region (op .operation )
195- return op
182+ return create_op (MatmulOp , * ins , outs = outs , indexing_maps = indexing_maps , cast = cast )
196183
197184
198185def batch_matmul (
199186 * ins : Union [Operation , OpView , Value ],
200187 outs : Sequence [Union [Operation , OpView , Value ]],
201188 indexing_maps : Optional [Sequence [AffineMapAttr ]] = None ,
189+ cast : Optional [Union [TypeFn , Attribute ]] = None ,
202190):
203- ins = [_get_op_result_or_value (input ) for input in ins ]
204- if len (outs ) > 1 :
205- raise ValueError (f"{ outs = } must have length 1." )
206- init = _get_op_result_or_value (outs [0 ])
207- result_types = [init .type ] if isinstance (init .type , RankedTensorType ) else []
191+ return create_op (
192+ BatchMatmulOp , * ins , outs = outs , indexing_maps = indexing_maps , cast = cast
193+ )
208194
209- op = BatchMatmulOp (
210- result_tensors = result_types ,
211- inputs = ins ,
212- outputs = [init ],
213- indexing_maps = indexing_maps ,
195+
196+ def contract (
197+ * ins : Union [Operation , OpView , Value ],
198+ outs : Sequence [Union [Operation , OpView , Value ]],
199+ indexing_maps : Sequence [AffineMapAttr ],
200+ cast : Optional [Union [TypeFn , Attribute ]] = None ,
201+ ):
202+ return create_op (
203+ ContractOp , * ins , outs = outs , indexing_maps = indexing_maps , cast = cast
214204 )
215- fill_builtin_region (op .operation )
216- return op
0 commit comments