
Commit 67ce816

nodlabscathyzhyi authored and committed
lowered addcmul and addcdiv to linalg
1 parent 8d8d2c2 commit 67ce816

File tree

6 files changed: +124, -0 lines changed


e2e_testing/torchscript/basic.py

Lines changed: 37 additions & 0 deletions
@@ -684,3 +684,40 @@ def forward(self, a, b, c):
 def ReturnThreeTensorFloat32_basic(module, tu: TestUtils):
     module.forward(tu.rand(2, 3), tu.rand(2, 3), tu.rand(2, 3))
 
+class AddCMulModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+        ([-1, -1], torch.float32, True),
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, input, tensor1, tensor2):
+        return torch.addcmul(input, tensor1, tensor2, value=1.0)
+
+@register_test_case(module_factory=lambda: AddCMulModule())
+def AddCMulModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 3), tu.rand(1, 3), tu.rand(1, 3))
+
+class AddCDivModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+        ([-1, -1], torch.float32, True),
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, input, tensor1, tensor2):
+        return torch.addcdiv(input, tensor1, tensor2, value=1.0)
+
+@register_test_case(module_factory=lambda: AddCDivModule())
+def AddCDivModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 3), tu.rand(1, 3), tu.rand(1, 3))
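
For context, these tests exercise the pointwise semantics addcmul(input, t1, t2, value) = input + value * t1 * t2 and addcdiv(input, t1, t2, value) = input + value * t1 / t2. A minimal standalone sketch (plain PyTorch, outside the e2e harness) that checks those identities:

import torch

x, a, b = torch.rand(1, 3), torch.rand(1, 3), torch.rand(1, 3)
v = 1.0
# addcmul: input + value * tensor1 * tensor2, elementwise.
assert torch.allclose(torch.addcmul(x, a, b, value=v), x + v * a * b)
# addcdiv: input + value * tensor1 / tensor2, elementwise.
assert torch.allclose(torch.addcdiv(x, a, b, value=v), x + v * a / b)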

e2e_testing/torchscript/xfail_sets.py

Lines changed: 2 additions & 0 deletions
@@ -30,4 +30,6 @@
     "ElementwiseLogModule_basic",
     "TanhBackward_basic",
     "ReturnThreeTensorFloat32_basic",
+    "AddCMulModule_basic",
+    "AddCDivModule_basic",
 }

include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td

Lines changed: 34 additions & 0 deletions
@@ -2895,3 +2895,37 @@ def Torch_Aten_LogSoftmaxBackwardDataOp : Torch_Op<"aten._log_softmax_backward_d
   let assemblyFormat = "$grad_output `,` $output `,` $dim `,` $input_dtype attr-dict `:` type($grad_output) `,` type($output) `,` type($dim) `,` type($input_dtype) `->` type($result)";
 }
 
+def Torch_AtenAddCMulOp : Torch_Op<"aten.addcmul", [
+    AllowsTypeRefinement,
+    HasValueSemantics
+  ]> {
+  let summary = "Generated op for `aten::addcmul : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$tensor1,
+    AnyTorchTensorType:$tensor2,
+    AnyTorchScalarType:$value
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let assemblyFormat = "$self `,` $tensor1 `,` $tensor2 `,` $value attr-dict `:` type($self) `,` type($tensor1) `,` type($tensor2) `,` type($value) `->` type($result)";
+}
+
+def Torch_AtenAddCDivOp : Torch_Op<"aten.addcdiv", [
+    AllowsTypeRefinement,
+    HasValueSemantics
+  ]> {
+  let summary = "Generated op for `aten::addcdiv : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$tensor1,
+    AnyTorchTensorType:$tensor2,
+    AnyTorchScalarType:$value
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let assemblyFormat = "$self `,` $tensor1 `,` $tensor2 `,` $value attr-dict `:` type($self) `,` type($tensor1) `,` type($tensor2) `,` type($value) `->` type($result)";
+}
+
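
These definitions mirror the TorchScript operator schemas one-for-one, so scripting any function that calls the op shows the aten::addcmul node that the importer maps onto Torch_AtenAddCMulOp. A quick, hedged way to inspect that in plain PyTorch, independent of this patch:

import torch

def f(x, a, b):
    return torch.addcmul(x, a, b, value=2.0)

# The printed graph contains an aten::addcmul node whose operand order
# (self, tensor1, tensor2, value) matches the ODS arguments above.
print(torch.jit.script(f).graph)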

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 24 additions & 0 deletions
@@ -375,6 +375,26 @@ class DecomposeAtenAddmmOp : public OpRewritePattern<AtenAddmmOp> {
 };
 } // namespace
 
+namespace {
+template <typename OpTy, typename T1T2Op>
+class DecomposeAtenAddCLikeOp : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+  LogicalResult matchAndRewrite(OpTy op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value input = op.self();
+    Value tensor1 = op.tensor1();
+    Value tensor2 = op.tensor2();
+    Value value = op.value();
+
+    Value product = rewriter.create<T1T2Op>(loc, op.getType(), tensor1, tensor2);
+    rewriter.replaceOpWithNewOp<AtenAddTensorOp>(op, op.getType(), input,
+                                                 product, value);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class DecomposeComplexOpsPass
     : public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
@@ -408,6 +428,10 @@ class DecomposeComplexOpsPass
       // Make aten.matmul legal if the following condition is satisfied.
      return (lhsRank != 2 || rhsRank != 2) && (lhsRank != 3 || rhsRank != 3);
     });
+    patterns.add<DecomposeAtenAddCLikeOp<AtenAddCMulOp, AtenMulTensorOp>>(context);
+    target.addIllegalOp<AtenAddCMulOp>();
+    patterns.add<DecomposeAtenAddCLikeOp<AtenAddCDivOp, AtenDivTensorOp>>(context);
+    target.addIllegalOp<AtenAddCDivOp>();
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns)))) {
       return signalPassFailure();
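
The pattern rewrites addcmul as a mul followed by an add, and addcdiv as a div followed by an add, carrying value through as the add's alpha operand. The numeric equivalence it relies on can be checked in plain PyTorch, since torch.add(x, y, alpha=v) computes x + v * y; a minimal sketch:

import torch

x, a, b = torch.rand(2, 3), torch.rand(2, 3), torch.rand(2, 3)
v = 3.0
# addcmul decomposes into mul followed by add-with-alpha.
assert torch.allclose(torch.addcmul(x, a, b, value=v),
                      torch.add(x, a * b, alpha=v))
# addcdiv decomposes into div followed by add-with-alpha.
assert torch.allclose(torch.addcdiv(x, a, b, value=v),
                      torch.add(x, a / b, alpha=v))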

lib/Dialect/Torch/Transforms/RefineTypes.cpp

Lines changed: 25 additions & 0 deletions
@@ -422,6 +422,8 @@ class TypeAnalyzer : public ForwardDataFlowAnalysis<ValueKnowledge> {
       return visitAtenSoftmaxLikeOp(logSoftmaxIntOp, operands);
     } else if (auto numToTensorOp = dyn_cast<PrimNumToTensorScalarOp>(op)) {
       return visitNumToTensorOp(numToTensorOp);
+    } else if (isa<AtenAddCMulOp, AtenAddCDivOp>(op)) {
+      return visitAtenAddCLikeOp(op, operands);
     }
 
     // Otherwise, this is an unknown operation. Just mark all results as
@@ -535,6 +537,10 @@ class TypeAnalyzer : public ForwardDataFlowAnalysis<ValueKnowledge> {
   ChangeResult
   visitAtenSoftmaxLikeOp(OpTy op,
                          ArrayRef<LatticeElement<ValueKnowledge> *> operands);
+
+  ChangeResult
+  visitAtenAddCLikeOp(Operation *op,
+                      ArrayRef<LatticeElement<ValueKnowledge> *> operands);
 };
 } // namespace
 
@@ -1376,6 +1382,25 @@ ChangeResult TypeAnalyzer::visitAtenMatmulOp(
   return getLatticeElement(op->getResult(0)).join(knowledge);
 }
 
+ChangeResult TypeAnalyzer::visitAtenAddCLikeOp(
+    Operation *op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
+  auto knowledge =
+      ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+  auto self = operands[0]->getValue();
+  auto tensor1 = operands[1]->getValue();
+  auto tensor2 = operands[2]->getValue();
+  if (tensor1.hasSizes && tensor2.hasSizes && self.hasSizes) {
+    knowledge.hasSizes = true;
+    knowledge.sizes.resize(
+        std::max(self.sizes.size(),
+                 std::max(tensor1.sizes.size(), tensor2.sizes.size())),
+        kUnknownSize);
+  }
+  knowledge.dtype =
+      getPromotedResultType(getContext(), {&self, &tensor1, &tensor2});
+  return getLatticeElement(op->getResult(0)).join(knowledge);
+}
+
 // -----------------------------------------------------------------------------
 // Transforms.
 // -----------------------------------------------------------------------------
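
visitAtenAddCLikeOp only claims a rank for the result: under broadcasting the output rank is the maximum of the operand ranks, and each dimension is left as kUnknownSize. The rank rule is easy to confirm in eager PyTorch; a small sketch with hypothetical shapes:

import torch

x = torch.rand(4, 1, 3)  # rank 3
a = torch.rand(2, 3)     # rank 2
b = torch.rand(3)        # rank 1
out = torch.addcmul(x, a, b)
# Broadcasting yields shape (4, 2, 3): the rank is the max of the
# operand ranks, which is all the analysis asserts statically.
assert out.shape == (4, 2, 3)
assert out.dim() == max(x.dim(), a.dim(), b.dim())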

python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py

Lines changed: 2 additions & 0 deletions
@@ -471,6 +471,8 @@ def emit_with_mutating_variants(key, **kwargs):
         emit_with_mutating_variants(key)
     # Elementwise tensor compute ops that don't have the standard mutating
     # variants.
+    emit("aten::addcmul : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)")
+    emit("aten::addcdiv : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)")
     emit("aten::maximum : (Tensor, Tensor) -> (Tensor)")
     emit("aten::minimum : (Tensor, Tensor) -> (Tensor)")
     emit("aten::rsub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)")
