@@ -1911,37 +1911,41 @@ class DecomposeAtenOuterOp : public OpRewritePattern<AtenOuterOp> {
1911
1911
auto inputType = cast<BaseTensorType>(input.getType ());
1912
1912
auto vec2Type = cast<BaseTensorType>(vec2.getType ());
1913
1913
1914
+ // Check if tensors not empty
1915
+ if (!inputType.hasSizes () || !vec2Type.hasSizes ()) {
1916
+ return rewriter.notifyMatchFailure (
1917
+ op, " Inputs must be ranked tensors for aten.outer" );
1918
+ }
1919
+
1914
1920
// Check if both tensors are 1-dimensional
1915
1921
SmallVector<int64_t > inputShape (inputType.getSizes ());
1916
1922
SmallVector<int64_t > vec2Shape (vec2Type.getSizes ());
1917
1923
1918
- if (inputShape.size () == 1 && vec2Shape.size () == 1 ) {
1924
+ if (inputShape.size () != 1 || vec2Shape.size () != 1 ) {
1925
+ return rewriter.notifyMatchFailure (
1926
+ op, " Inputs must be 1-dimensional vectors for aten.outer" );
1927
+ }
1919
1928
1920
- Value one = rewriter.create <Torch::ConstantIntOp>(
1921
- loc, rewriter.getI64IntegerAttr (1 )); // Dimension index
1922
- SmallVector<int64_t , 2 > inputMatrixShape = {inputShape[0 ], 1 };
1923
- Type inputMatrixType = inputType.getWithSizesAndDtype (
1924
- inputMatrixShape, inputType.getOptionalDtype ());
1929
+ Value one = rewriter.create <Torch::ConstantIntOp>(
1930
+ loc, rewriter.getI64IntegerAttr (1 )); // Dimension index
1931
+ SmallVector<int64_t , 2 > inputMatrixShape = {inputShape[0 ], 1 };
1932
+ Type inputMatrixType = inputType.getWithSizesAndDtype (
1933
+ inputMatrixShape, inputType.getOptionalDtype ());
1925
1934
1926
- Value inputMatrix =
1927
- rewriter.create <AtenUnsqueezeOp>(loc, inputMatrixType, input, one);
1935
+ Value inputMatrix =
1936
+ rewriter.create <AtenUnsqueezeOp>(loc, inputMatrixType, input, one);
1928
1937
1929
- Value zero = rewriter.create <Torch::ConstantIntOp>(
1930
- loc, rewriter.getI64IntegerAttr (0 ));
1931
- SmallVector<int64_t , 2 > vec2MatrixShape = {1 , vec2Shape[0 ]};
1932
- Type vec2MatrixType = vec2Type.getWithSizesAndDtype (
1933
- vec2MatrixShape, vec2Type.getOptionalDtype ());
1934
-
1935
- Value vec2Matrix =
1936
- rewriter.create <AtenUnsqueezeOp>(loc, vec2MatrixType, vec2, zero);
1938
+ Value zero = rewriter.create <Torch::ConstantIntOp>(
1939
+ loc, rewriter.getI64IntegerAttr (0 ));
1940
+ SmallVector<int64_t , 2 > vec2MatrixShape = {1 , vec2Shape[0 ]};
1941
+ Type vec2MatrixType = vec2Type.getWithSizesAndDtype (
1942
+ vec2MatrixShape, vec2Type.getOptionalDtype ());
1937
1943
1938
- rewriter.replaceOpWithNewOp <AtenMatmulOp>(op, opType, inputMatrix,
1939
- vec2Matrix);
1940
- return success ();
1941
- } else {
1942
- return failure ();
1943
- }
1944
+ Value vec2Matrix =
1945
+ rewriter.create <AtenUnsqueezeOp>(loc, vec2MatrixType, vec2, zero);
1944
1946
1947
+ rewriter.replaceOpWithNewOp <AtenMatmulOp>(op, opType, inputMatrix,
1948
+ vec2Matrix);
1945
1949
return success ();
1946
1950
}
1947
1951
};
0 commit comments