
Commit ca85d0d

Rewrote solution via decomposition
- Co-author: @ivanamitreski
1 parent 52a91a9 commit ca85d0d

4 files changed, +56 -76 lines changed

lib/Conversion/TorchToLinalg/Linear.cpp

Lines changed: 0 additions & 76 deletions
@@ -1759,80 +1759,6 @@ struct ConvertAtenFftRfftOp final : OpConversionPattern<AtenFftRfftOp> {
 
 } // namespace
 
-namespace {
-class ConvertAtenOuterOp : public OpConversionPattern<AtenOuterOp> {
-public:
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(AtenOuterOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-
-    Location loc = op->getLoc();
-    Value lhs = adaptor.getSelf();
-    Value rhs = adaptor.getVec2();
-
-    if (failed(verifyLinalgCompatibleTypes(op, rewriter))) {
-      return failure();
-    }
-    auto lhsType = dyn_cast<RankedTensorType>(lhs.getType());
-    auto rhsType = dyn_cast<RankedTensorType>(rhs.getType());
-
-    if (!lhsType || !rhsType)
-      return rewriter.notifyMatchFailure(op,
-                                         "outer: expected ranked tensor types");
-    if (lhsType.getRank() != 1 || rhsType.getRank() != 1)
-      return rewriter.notifyMatchFailure(
-          op, "outer: expected 1D tensors for outer op lowering");
-
-    Value lhsDim = getDimOp(rewriter, loc, lhs, 0);
-    Value rhsDim = getDimOp(rewriter, loc, rhs, 0);
-    Type elementType = lhsType.getElementType();
-    Type newResultType = getTypeConverter()->convertType(op.getType());
-
-    // Create a zero-initialized tensor with shape [lhsDim, rhsDim]
-    SmallVector<OpFoldResult> resultShape =
-        getAsOpFoldResult(ValueRange{lhsDim, rhsDim});
-    Value initTensor =
-        rewriter.create<tensor::EmptyOp>(loc, resultShape, elementType);
-
-    // Set up affine indexing maps:
-    // We create a 2D loop iteration space. For the lhs, we use the first index
-    // (i), for the rhs, the second index (j), and for the result, both (i, j).
-    AffineMap mapLhs =
-        AffineMap::get(2, /*symbolCount=*/0, {rewriter.getAffineDimExpr(0)},
-                       rewriter.getContext());
-    AffineMap mapRhs =
-        AffineMap::get(2, /*symbolCount=*/0, {rewriter.getAffineDimExpr(1)},
-                       rewriter.getContext());
-    AffineMap mapOut =
-        AffineMap::getMultiDimIdentityMap(2, rewriter.getContext());
-
-    SmallVector<utils::IteratorType, 2> iteratorTypes = {
-        utils::IteratorType::parallel, utils::IteratorType::parallel};
-
-    Value outerProd =
-        rewriter
-            .create<linalg::GenericOp>(
-                loc, initTensor.getType(),
-                /*inputs=*/ValueRange{lhsDim, rhsDim},
-                /*outputs=*/initTensor,
-                /*indexingMaps=*/
-                SmallVector<AffineMap, 3>{mapLhs, mapRhs, mapOut},
-                /*iteratortType=*/iteratorTypes,
-                [&](OpBuilder &b, Location loc, ValueRange args) {
-                  Value lhsElem = args[0];
-                  Value rhsElem = args[1];
-                  Value mult = b.create<arith::MulFOp>(loc, lhsElem, rhsElem);
-                  b.create<linalg::YieldOp>(loc, mult);
-                })
-            .getResult(0);
-
-    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, outerProd);
-    return success();
-  }
-};
-} // namespace
-
 void mlir::torch::torch_to_linalg::populateLinearPatternsAndLegality(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
     ConversionTarget &target) {
@@ -1849,6 +1775,4 @@ void mlir::torch::torch_to_linalg::populateLinearPatternsAndLegality(
   patterns.add<ConvertAtenConvolutionOp>(typeConverter, context);
   target.addIllegalOp<AtenFftRfftOp>();
   patterns.add<ConvertAtenFftRfftOp>(typeConverter, context);
-  target.addIllegalOp<AtenOuterOp>();
-  patterns.add<ConvertAtenOuterOp>(typeConverter, context);
 }
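Note: the deleted ConvertAtenOuterOp pattern above lowered aten.outer directly to a 2-D linalg.generic with two parallel loops, computing out[i][j] = lhs[i] * rhs[j] (lhs indexed by d0, rhs by d1, the result by the identity map). A minimal Python sketch of that iteration space, with made-up operand values that are not part of this commit:

    # Semantics of the removed linalg.generic body: two parallel loops (i, j),
    # each iteration multiplies one element of lhs by one element of rhs.
    lhs = [1.0, 2.0, 3.0]
    rhs = [4.0, 5.0]
    out = [[0.0] * len(rhs) for _ in range(len(lhs))]
    for i in range(len(lhs)):        # d0: indexes lhs and the first result dim
        for j in range(len(rhs)):    # d1: indexes rhs and the second result dim
            out[i][j] = lhs[i] * rhs[j]
    assert out == [[4.0, 5.0], [8.0, 10.0], [12.0, 15.0]]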

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 54 additions & 0 deletions
@@ -1894,6 +1894,59 @@ class DecomposeAtenAtleast1dOp : public OpRewritePattern<AtenAtleast1dOp> {
 };
 } // namespace
 
+// Decompose 'aten.outer' into 'aten.unsqueeze', 'aten.matmul'
+
+namespace {
+class DecomposeAtenOuterOp : public OpRewritePattern<AtenOuterOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenOuterOp op,
+                                PatternRewriter &rewriter) const override {
+
+    Location loc = op.getLoc();
+    Value input = op.getSelf();
+    Value vec2 = op.getVec2();
+    Type opType = op.getType();
+
+    auto inputType = cast<BaseTensorType>(input.getType());
+    auto vec2Type = cast<BaseTensorType>(vec2.getType());
+
+    // Check if both tensors are 1-dimensional
+    SmallVector<int64_t> inputShape(inputType.getSizes());
+    SmallVector<int64_t> vec2Shape(vec2Type.getSizes());
+
+    if (inputShape.size() == 1 && vec2Shape.size() == 1) {
+
+      Value one = rewriter.create<Torch::ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(1)); // Dimension index
+      SmallVector<int64_t, 2> inputMatrixShape = {inputShape[0], 1};
+      Type inputMatrixType = inputType.getWithSizesAndDtype(
+          inputMatrixShape, inputType.getOptionalDtype());
+
+      Value inputMatrix =
+          rewriter.create<AtenUnsqueezeOp>(loc, inputMatrixType, input, one);
+
+      Value zero = rewriter.create<Torch::ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(0));
+      SmallVector<int64_t, 2> vec2MatrixShape = {1, vec2Shape[0]};
+      Type vec2MatrixType = vec2Type.getWithSizesAndDtype(
+          vec2MatrixShape, vec2Type.getOptionalDtype());
+
+      Value vec2Matrix =
+          rewriter.create<AtenUnsqueezeOp>(loc, vec2MatrixType, vec2, zero);
+
+      rewriter.replaceOpWithNewOp<AtenMatmulOp>(op, opType, inputMatrix,
+                                                vec2Matrix);
+      return success();
+    } else {
+      return failure();
+    }
+
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 // Decompose aten.atleast_2d into: aten.reshape. See
 // https://github.com/pytorch/pytorch/blob/9a8ab778d34bd24c5caceb340837483decc4c311/torch/_refs/__init__.py#L2604
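Note: the DecomposeAtenOuterOp pattern added above relies on the identity that, for 1-D inputs, outer(a, b) equals a matmul of a reshaped to [n, 1] with b reshaped to [1, m]. A quick PyTorch check of that identity (illustrative values only, assumes a local torch install; not part of this commit):

    import torch

    a = torch.tensor([1.0, 2.0, 3.0])  # 1-D, shape [3]
    b = torch.tensor([4.0, 5.0])       # 1-D, shape [2]

    # unsqueeze(1): [3] -> [3, 1]; unsqueeze(0): [2] -> [1, 2];
    # an ordinary matmul then produces the [3, 2] outer product.
    decomposed = torch.matmul(a.unsqueeze(1), b.unsqueeze(0))
    assert torch.equal(decomposed, torch.outer(a, b))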
@@ -11603,6 +11656,7 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal<DecomposeAtenEmptyLikeOp>(patterns);
     addPatternIfTargetOpIsIllegal<
         DecomposeConstantTensorAllocLikeOp<AtenOnesLikeOp, 1>>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenOuterOp>(patterns);
     addPatternIfTargetOpIsIllegal<
         DecomposeConstantTensorAllocLikeOp<AtenZerosLikeOp, 0>>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenStackOp>(patterns);

lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp

Lines changed: 1 addition & 0 deletions
@@ -379,6 +379,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp<AtenSoftshrinkOp>();
   target.addIllegalOp<AtenEmptyLikeOp>();
   target.addIllegalOp<AtenOnesLikeOp>();
+  target.addIllegalOp<AtenOuterOp>();
   target.addIllegalOp<AtenZerosLikeOp>();
   target.addIllegalOp<AtenStackOp>();
   target.addIllegalOp<AtenHstackOp>();

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 1 addition & 0 deletions
@@ -3815,6 +3815,7 @@
 }
 
 ONNX_TOSA_XFAIL_SET = {
+    "AtenOuter_basic",
     "AtenFftRfft2DLastDim_basic",
     "AtenFftRfft2DMiddleDim_basic",
     "AtenNonzero1DDynamicModule_basic",
