Skip to content

Commit 73dfc8a

Browse files
committed
Added missing change
1 parent e8d511c commit 73dfc8a

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

lib/Conversion/TorchToLinalg/Linear.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1704,8 +1704,8 @@ class ConvertAtenOuterOp : public OpConversionPattern<AtenOuterOp> {
17041704
Type newResultType = getTypeConverter()->convertType(op.getType());
17051705

17061706
// Create a zero-initialized tensor with shape [lhsDim, rhsDim]
1707-
Value zeroTensor = createZeroInitTensor(
1708-
rewriter, loc, ValueRange{lhsDim, rhsDim}, elementType);
1707+
Value initTensor = createInitTensor(
1708+
rewriter, loc, ValueRange{lhsDim, rhsDim}, elementType, NULL);
17091709

17101710
// Set up affine indexing maps:
17111711
// We create a 2D loop iteration space. For the lhs, we use the first index
@@ -1725,9 +1725,9 @@ class ConvertAtenOuterOp : public OpConversionPattern<AtenOuterOp> {
17251725
Value outerProd =
17261726
rewriter
17271727
.create<linalg::GenericOp>(
1728-
loc, zeroTensor.getType(),
1728+
loc, initTensor.getType(),
17291729
/*inputs=*/ValueRange{lhsDim, rhsDim},
1730-
/*outputs=*/zeroTensor,
1730+
/*outputs=*/initTensor,
17311731
/*indexingMaps=*/
17321732
SmallVector<AffineMap, 3>{mapLhs, mapRhs, mapOut},
17331733
/*iteratortType=*/iteratorTypes,

0 commit comments

Comments
 (0)