Skip to content

Commit 8f63eaa

Browse files
committed
Refactored the linalg op Python wrapper creation with a helper function.
1 parent 3be9234 commit 8f63eaa

File tree

3 files changed

+36
-39
lines changed

3 files changed

+36
-39
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -859,9 +859,10 @@ def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSiz
859859
Variadic<AnyType>:$inputs,
860860
Variadic<AnyShaped>:$outputs,
861861
DefaultValuedOptionalAttr<
862-
AffineMapArrayAttr,
863-
"BatchMatmulOp::getDefaultIndexingMaps($_builder.getContext())"
864-
>:$indexing_maps
862+
AffineMapArrayAttr,
863+
"BatchMatmulOp::getDefaultIndexingMaps($_builder.getContext())"
864+
>:$indexing_maps,
865+
DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
865866
);
866867
let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
867868
let regions = (region AnyRegion:$region);
@@ -887,9 +888,10 @@ def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSiz
887888
}]>,
888889
OpBuilder<
889890
(ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
890-
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
891+
"Attribute":$cast, CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
891892
[{
892893
$_state.addOperands(operands);
894+
$_state.addAttribute("cast", cast);
893895
$_state.addAttributes(attributes);
894896
$_state.addTypes(resultTensorTypes);
895897
(void)$_state.addRegion(),

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3951,11 +3951,18 @@ void BatchMatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
39513951
RegionBuilderHelper helper(b, block);
39523952
SmallVector<Value> yields;
39533953

3954+
TypeFn castVal = TypeFn::cast_signed;
3955+
auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
3956+
return attr.getName() == "cast";
3957+
});
3958+
if (castIter != attrs.end()) {
3959+
if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
3960+
castVal = attr.getValue();
3961+
}
3962+
39543963
auto toType = block.getArgument(2).getType();
3955-
Value castValA =
3956-
helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(0));
3957-
Value castValB =
3958-
helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(1));
3964+
Value castValA = helper.buildTypeFn(castVal, toType, block.getArgument(0));
3965+
Value castValB = helper.buildTypeFn(castVal, toType, block.getArgument(1));
39593966
Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);
39603967
Value addVal =
39613968
helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal);

mlir/python/mlir/dialects/linalg/__init__.py

Lines changed: 19 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,8 @@ def __init__(
149149
generic = region_op(GenericOp_, terminator=YieldOp)
150150

151151

152-
def matmul(
152+
def create_op(
153+
op_type,
153154
*ins: Union[Operation, OpView, Value],
154155
outs: Sequence[Union[Operation, OpView, Value]],
155156
indexing_maps: Optional[Sequence[AffineMapAttr]] = None,
@@ -161,7 +162,7 @@ def matmul(
161162
init = _get_op_result_or_value(outs[0])
162163
result_types = [init.type] if isinstance(init.type, RankedTensorType) else []
163164

164-
op = MatmulOp(
165+
op = op_type(
165166
result_tensors=result_types,
166167
inputs=ins,
167168
outputs=[init],
@@ -172,45 +173,32 @@ def matmul(
172173
return op
173174

174175

175-
def contract(
176+
def matmul(
176177
*ins: Union[Operation, OpView, Value],
177178
outs: Sequence[Union[Operation, OpView, Value]],
178-
indexing_maps: Sequence[AffineMapAttr],
179+
indexing_maps: Optional[Sequence[AffineMapAttr]] = None,
179180
cast: Optional[Union[TypeFn, Attribute]] = None,
180181
):
181-
ins = [_get_op_result_or_value(input) for input in ins]
182-
if len(outs) > 1:
183-
raise ValueError(f"{outs=} must have length 1.")
184-
init = _get_op_result_or_value(outs[0])
185-
result_types = [init.type] if isinstance(init.type, RankedTensorType) else []
186-
187-
op = ContractOp(
188-
result_tensors=result_types,
189-
inputs=ins,
190-
outputs=[init],
191-
indexing_maps=indexing_maps,
192-
cast=cast,
193-
)
194-
fill_builtin_region(op.operation)
195-
return op
182+
return create_op(MatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast)
196183

197184

198185
def batch_matmul(
199186
*ins: Union[Operation, OpView, Value],
200187
outs: Sequence[Union[Operation, OpView, Value]],
201188
indexing_maps: Optional[Sequence[AffineMapAttr]] = None,
189+
cast: Optional[Union[TypeFn, Attribute]] = None,
202190
):
203-
ins = [_get_op_result_or_value(input) for input in ins]
204-
if len(outs) > 1:
205-
raise ValueError(f"{outs=} must have length 1.")
206-
init = _get_op_result_or_value(outs[0])
207-
result_types = [init.type] if isinstance(init.type, RankedTensorType) else []
191+
return create_op(
192+
BatchMatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
193+
)
208194

209-
op = BatchMatmulOp(
210-
result_tensors=result_types,
211-
inputs=ins,
212-
outputs=[init],
213-
indexing_maps=indexing_maps,
195+
196+
def contract(
197+
*ins: Union[Operation, OpView, Value],
198+
outs: Sequence[Union[Operation, OpView, Value]],
199+
indexing_maps: Sequence[AffineMapAttr],
200+
cast: Optional[Union[TypeFn, Attribute]] = None,
201+
):
202+
return create_op(
203+
ContractOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
214204
)
215-
fill_builtin_region(op.operation)
216-
return op

0 commit comments

Comments (0)