Skip to content

Commit eb04f32

Browse files
tatwaichongrsuderman
authored andcommitted
[tosa] Add legalization for conv3d
Update the existing implementation to match TOSA spec. Reviewed By: rsuderman Differential Revision: https://reviews.llvm.org/D133062
1 parent 3f96581 commit eb04f32

File tree

1 file changed

+23
-24
lines changed

1 file changed

+23
-24
lines changed

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 23 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1070,7 +1070,7 @@ LogicalResult Conv3DOp::inferReturnTypeComponents(
10701070
ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
10711071
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
10721072
llvm::SmallVector<int64_t> outputShape(5, ShapedType::kDynamicSize);
1073-
Conv2DOp::Adaptor adaptor(operands.getValues(), attributes);
1073+
Conv3DOp::Adaptor adaptor(operands.getValues(), attributes);
10741074

10751075
int32_t inputWidth = ShapedType::kDynamicSize;
10761076
int32_t inputHeight = ShapedType::kDynamicSize;
@@ -1084,55 +1084,54 @@ LogicalResult Conv3DOp::inferReturnTypeComponents(
10841084
ShapeAdaptor inputShape = operands.getShape(adaptor.getInput());
10851085
if (inputShape.hasRank()) {
10861086
outputShape[0] = inputShape.getDimSize(0);
1087-
inputHeight = inputShape.getDimSize(1);
1088-
inputWidth = inputShape.getDimSize(2);
1089-
inputDepth = inputShape.getDimSize(3);
1087+
inputDepth = inputShape.getDimSize(1);
1088+
inputHeight = inputShape.getDimSize(2);
1089+
inputWidth = inputShape.getDimSize(3);
10901090
}
10911091

10921092
// Weight shapes describes the filter width/height and the output channels.
10931093
ShapeAdaptor weightShape = operands.getShape(adaptor.getWeight());
10941094
if (weightShape.hasRank()) {
10951095
outputShape[4] = weightShape.getDimSize(0);
1096-
weightHeight = weightShape.getDimSize(1);
1097-
weightWidth = weightShape.getDimSize(2);
1098-
weightDepth = weightShape.getDimSize(3);
1096+
weightDepth = weightShape.getDimSize(1);
1097+
weightHeight = weightShape.getDimSize(2);
1098+
weightWidth = weightShape.getDimSize(3);
10991099
}
11001100

11011101
// Bias shape can describe the output channels.
11021102
ShapeAdaptor biasShape = operands.getShape(adaptor.getBias());
1103-
if (biasShape.hasRank()) {
1104-
outputShape[4] =
1105-
(outputShape[4] == -1) ? biasShape.getDimSize(0) : outputShape[4];
1103+
if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[4])) {
1104+
outputShape[4] = biasShape.getDimSize(0);
11061105
}
11071106

11081107
llvm::SmallVector<int64_t> dilation;
1109-
llvm::SmallVector<int64_t> padding;
1108+
llvm::SmallVector<int64_t> pad;
11101109
llvm::SmallVector<int64_t> stride;
11111110

11121111
getI64Values(adaptor.getDilation(), dilation);
1113-
getI64Values(adaptor.getPad(), padding);
1112+
getI64Values(adaptor.getPad(), pad);
11141113
getI64Values(adaptor.getStride(), stride);
11151114

1116-
if (!ShapedType::isDynamic(inputHeight) &&
1117-
!ShapedType::isDynamic(weightHeight)) {
1118-
int32_t inputSize = inputHeight + padding[0] + padding[1];
1119-
int32_t filterSize = (weightHeight - 1) * dilation[0] + 1;
1115+
if (!ShapedType::isDynamic(inputDepth) &&
1116+
!ShapedType::isDynamic(weightDepth)) {
1117+
int32_t inputSize = inputDepth + pad[0] + pad[1];
1118+
int32_t filterSize = (weightDepth - 1) * dilation[0] + 1;
11201119
int32_t unstridedResult = inputSize - filterSize + 1;
11211120
outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
11221121
}
11231122

1124-
if (!ShapedType::isDynamic(inputWidth) &&
1125-
!ShapedType::isDynamic(weightWidth)) {
1126-
int32_t inputSize = inputWidth + padding[2] + padding[3];
1127-
int32_t filterSize = (weightWidth - 1) * dilation[1] + 1;
1123+
if (!ShapedType::isDynamic(inputHeight) &&
1124+
!ShapedType::isDynamic(weightHeight)) {
1125+
int32_t inputSize = inputHeight + pad[2] + pad[3];
1126+
int32_t filterSize = (weightHeight - 1) * dilation[1] + 1;
11281127
int32_t unstridedResult = inputSize - filterSize + 1;
11291128
outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
11301129
}
11311130

1132-
if (!ShapedType::isDynamic(inputDepth) &&
1133-
!ShapedType::isDynamic(weightDepth)) {
1134-
int32_t inputSize = inputDepth + padding[4] + padding[5];
1135-
int32_t filterSize = (weightDepth - 1) * dilation[2] + 1;
1131+
if (!ShapedType::isDynamic(inputWidth) &&
1132+
!ShapedType::isDynamic(weightWidth)) {
1133+
int32_t inputSize = inputWidth + pad[4] + pad[5];
1134+
int32_t filterSize = (weightWidth - 1) * dilation[2] + 1;
11361135
int32_t unstridedResult = inputSize - filterSize + 1;
11371136
outputShape[3] = (unstridedResult - 1) / stride[2] + 1;
11381137
}

0 commit comments

Comments
 (0)