Skip to content

Commit 2debe15

Browse files
Update Conv2D creation to include accumulate type attribute.
Signed-off-by: MaheshRavishankar <[email protected]>
1 parent f92c587 commit 2debe15

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2370,7 +2370,8 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
23702370
transposedInput, transformedWeight, bias,
23712371
rewriter.getDenseI64ArrayAttr(padding),
23722372
rewriter.getDenseI64ArrayAttr(stride),
2373-
rewriter.getDenseI64ArrayAttr(dilation))
2373+
rewriter.getDenseI64ArrayAttr(dilation),
2374+
TypeAttr::get(biasElemTy))
23742375
.getResult();
23752376
} else if (weightShape[1] == 1) {
23762377
// depthwise convolution
@@ -2381,7 +2382,8 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
23812382
transposedInput, transformedWeight, bias,
23822383
rewriter.getDenseI64ArrayAttr(padding),
23832384
rewriter.getDenseI64ArrayAttr(stride),
2384-
rewriter.getDenseI64ArrayAttr(dilation))
2385+
rewriter.getDenseI64ArrayAttr(dilation),
2386+
TypeAttr::get(biasElemTy))
23852387
.getResult();
23862388
} else {
23872389
llvm_unreachable("Unhandled convolution type");

0 commit comments

Comments
 (0)