Skip to content

Commit f97b948

Browse files
Merge pull request #406 from Xilinx/chaitany.convtranspose_fixing_onnxpad_for_weights
fixing the onnx.pad constant used for padding the weights
1 parent a2be298 commit f97b948

File tree

2 files changed

+146
-3
lines changed

2 files changed

+146
-3
lines changed

src/Dialect/ONNX/Transforms/Decompose.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1549,9 +1549,10 @@ Value decomposeIntoPhasedConvs(PatternRewriter &rewriter, Location loc,
15491549
}
15501550
auto onnxPadsValueConstant =
15511551
getONNXConstOpFromVector(rewriter, loc, weightsPadValue);
1552-
RankedTensorType scalarTy = RankedTensorType::get({}, elementType);
1553-
Value onnxPaddingConstantZero = create.onnx.constant(
1554-
DenseElementsAttr::get(scalarTy, rewriter.getZeroAttr(elementType)));
1552+
auto weightsElementType = weightsType.getElementType();
1553+
RankedTensorType scalarTy = RankedTensorType::get({}, weightsElementType);
1554+
Value onnxPaddingConstantZero = create.onnx.constant(DenseElementsAttr::get(
1555+
scalarTy, rewriter.getZeroAttr(weightsElementType)));
15551556

15561557
auto onnxAxisValueConstantNone = create.onnx.none();
15571558
auto wts_shape = weightsType.getShape();

0 commit comments

Comments
 (0)