Skip to content

Commit 6fefb77

Browse files
committed
Rewrote solution via decomposition
-Co-author: @ivanamitreski
1 parent 0e88737 commit 6fefb77

File tree

4 files changed

+56
-76
lines changed

4 files changed

+56
-76
lines changed

lib/Conversion/TorchToLinalg/Linear.cpp

Lines changed: 0 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -1673,80 +1673,6 @@ struct ConvertAtenFftRfftOp final : OpConversionPattern<AtenFftRfftOp> {
16731673

16741674
} // namespace
16751675

1676-
namespace {
1677-
class ConvertAtenOuterOp : public OpConversionPattern<AtenOuterOp> {
1678-
public:
1679-
using OpConversionPattern::OpConversionPattern;
1680-
LogicalResult
1681-
matchAndRewrite(AtenOuterOp op, OpAdaptor adaptor,
1682-
ConversionPatternRewriter &rewriter) const override {
1683-
1684-
Location loc = op->getLoc();
1685-
Value lhs = adaptor.getSelf();
1686-
Value rhs = adaptor.getVec2();
1687-
1688-
if (failed(verifyLinalgCompatibleTypes(op, rewriter))) {
1689-
return failure();
1690-
}
1691-
auto lhsType = dyn_cast<RankedTensorType>(lhs.getType());
1692-
auto rhsType = dyn_cast<RankedTensorType>(rhs.getType());
1693-
1694-
if (!lhsType || !rhsType)
1695-
return rewriter.notifyMatchFailure(op,
1696-
"outer: expected ranked tensor types");
1697-
if (lhsType.getRank() != 1 || rhsType.getRank() != 1)
1698-
return rewriter.notifyMatchFailure(
1699-
op, "outer: expected 1D tensors for outer op lowering");
1700-
1701-
Value lhsDim = getDimOp(rewriter, loc, lhs, 0);
1702-
Value rhsDim = getDimOp(rewriter, loc, rhs, 0);
1703-
Type elementType = lhsType.getElementType();
1704-
Type newResultType = getTypeConverter()->convertType(op.getType());
1705-
1706-
// Create a zero-initialized tensor with shape [lhsDim, rhsDim]
1707-
SmallVector<OpFoldResult> resultShape =
1708-
getAsOpFoldResult(ValueRange{lhsDim, rhsDim});
1709-
Value initTensor =
1710-
rewriter.create<tensor::EmptyOp>(loc, resultShape, elementType);
1711-
1712-
// Set up affine indexing maps:
1713-
// We create a 2D loop iteration space. For the lhs, we use the first index
1714-
// (i), for the rhs, the second index (j), and for the result, both (i, j).
1715-
AffineMap mapLhs =
1716-
AffineMap::get(2, /*symbolCount=*/0, {rewriter.getAffineDimExpr(0)},
1717-
rewriter.getContext());
1718-
AffineMap mapRhs =
1719-
AffineMap::get(2, /*symbolCount=*/0, {rewriter.getAffineDimExpr(1)},
1720-
rewriter.getContext());
1721-
AffineMap mapOut =
1722-
AffineMap::getMultiDimIdentityMap(2, rewriter.getContext());
1723-
1724-
SmallVector<utils::IteratorType, 2> iteratorTypes = {
1725-
utils::IteratorType::parallel, utils::IteratorType::parallel};
1726-
1727-
Value outerProd =
1728-
rewriter
1729-
.create<linalg::GenericOp>(
1730-
loc, initTensor.getType(),
1731-
/*inputs=*/ValueRange{lhsDim, rhsDim},
1732-
/*outputs=*/initTensor,
1733-
/*indexingMaps=*/
1734-
SmallVector<AffineMap, 3>{mapLhs, mapRhs, mapOut},
1735-
/*iteratortType=*/iteratorTypes,
1736-
[&](OpBuilder &b, Location loc, ValueRange args) {
1737-
Value lhsElem = args[0];
1738-
Value rhsElem = args[1];
1739-
Value mult = b.create<arith::MulFOp>(loc, lhsElem, rhsElem);
1740-
b.create<linalg::YieldOp>(loc, mult);
1741-
})
1742-
.getResult(0);
1743-
1744-
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, outerProd);
1745-
return success();
1746-
}
1747-
};
1748-
} // namespace
1749-
17501676
void mlir::torch::torch_to_linalg::populateLinearPatternsAndLegality(
17511677
TypeConverter &typeConverter, RewritePatternSet &patterns,
17521678
ConversionTarget &target) {
@@ -1763,6 +1689,4 @@ void mlir::torch::torch_to_linalg::populateLinearPatternsAndLegality(
17631689
patterns.add<ConvertAtenConvolutionOp>(typeConverter, context);
17641690
target.addIllegalOp<AtenFftRfftOp>();
17651691
patterns.add<ConvertAtenFftRfftOp>(typeConverter, context);
1766-
target.addIllegalOp<AtenOuterOp>();
1767-
patterns.add<ConvertAtenOuterOp>(typeConverter, context);
17681692
}

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1894,6 +1894,59 @@ class DecomposeAtenAtleast1dOp : public OpRewritePattern<AtenAtleast1dOp> {
18941894
};
18951895
} // namespace
18961896

1897+
// Decompose 'aten.outer' into 'aten.unsqueeze', 'aten.matmul'
1898+
1899+
namespace {
1900+
class DecomposeAtenOuterOp : public OpRewritePattern<AtenOuterOp> {
1901+
public:
1902+
using OpRewritePattern::OpRewritePattern;
1903+
LogicalResult matchAndRewrite(AtenOuterOp op,
1904+
PatternRewriter &rewriter) const override {
1905+
1906+
Location loc = op.getLoc();
1907+
Value input = op.getSelf();
1908+
Value vec2 = op.getVec2();
1909+
Type opType = op.getType();
1910+
1911+
auto inputType = cast<BaseTensorType>(input.getType());
1912+
auto vec2Type = cast<BaseTensorType>(vec2.getType());
1913+
1914+
// Check if both tensors are 1-dimensional
1915+
SmallVector<int64_t> inputShape(inputType.getSizes());
1916+
SmallVector<int64_t> vec2Shape(vec2Type.getSizes());
1917+
1918+
if (inputShape.size() == 1 && vec2Shape.size() == 1) {
1919+
1920+
Value one = rewriter.create<Torch::ConstantIntOp>(
1921+
loc, rewriter.getI64IntegerAttr(1)); // Dimension index
1922+
SmallVector<int64_t, 2> inputMatrixShape = {inputShape[0], 1};
1923+
Type inputMatrixType = inputType.getWithSizesAndDtype(
1924+
inputMatrixShape, inputType.getOptionalDtype());
1925+
1926+
Value inputMatrix =
1927+
rewriter.create<AtenUnsqueezeOp>(loc, inputMatrixType, input, one);
1928+
1929+
Value zero = rewriter.create<Torch::ConstantIntOp>(
1930+
loc, rewriter.getI64IntegerAttr(0));
1931+
SmallVector<int64_t, 2> vec2MatrixShape = {1, vec2Shape[0]};
1932+
Type vec2MatrixType = vec2Type.getWithSizesAndDtype(
1933+
vec2MatrixShape, vec2Type.getOptionalDtype());
1934+
1935+
Value vec2Matrix =
1936+
rewriter.create<AtenUnsqueezeOp>(loc, vec2MatrixType, vec2, zero);
1937+
1938+
rewriter.replaceOpWithNewOp<AtenMatmulOp>(op, opType, inputMatrix,
1939+
vec2Matrix);
1940+
return success();
1941+
} else {
1942+
return failure();
1943+
}
1944+
1945+
return success();
1946+
}
1947+
};
1948+
} // namespace
1949+
18971950
namespace {
18981951
// Decompose aten.atleast_2d into: aten.reshape. See
18991952
// https://github.com/pytorch/pytorch/blob/9a8ab778d34bd24c5caceb340837483decc4c311/torch/_refs/__init__.py#L2604
@@ -11591,6 +11644,7 @@ class DecomposeComplexOpsPass
1159111644
addPatternIfTargetOpIsIllegal<DecomposeAtenEmptyLikeOp>(patterns);
1159211645
addPatternIfTargetOpIsIllegal<
1159311646
DecomposeConstantTensorAllocLikeOp<AtenOnesLikeOp, 1>>(patterns);
11647+
addPatternIfTargetOpIsIllegal<DecomposeAtenOuterOp>(patterns);
1159411648
addPatternIfTargetOpIsIllegal<
1159511649
DecomposeConstantTensorAllocLikeOp<AtenZerosLikeOp, 0>>(patterns);
1159611650
addPatternIfTargetOpIsIllegal<DecomposeAtenStackOp>(patterns);

lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
379379
target.addIllegalOp<AtenSoftshrinkOp>();
380380
target.addIllegalOp<AtenEmptyLikeOp>();
381381
target.addIllegalOp<AtenOnesLikeOp>();
382+
target.addIllegalOp<AtenOuterOp>();
382383
target.addIllegalOp<AtenZerosLikeOp>();
383384
target.addIllegalOp<AtenStackOp>();
384385
target.addIllegalOp<AtenHstackOp>();

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3810,6 +3810,7 @@
38103810
}
38113811

38123812
ONNX_TOSA_XFAIL_SET = {
3813+
"AtenOuter_basic",
38133814
"AtenFftRfft2DLastDim_basic",
38143815
"AtenFftRfft2DMiddleDim_basic",
38153816
"AtenNonzero1DDynamicModule_basic",

0 commit comments

Comments
 (0)