@@ -1070,7 +1070,7 @@ LogicalResult Conv3DOp::inferReturnTypeComponents(
1070
1070
ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
1071
1071
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1072
1072
llvm::SmallVector<int64_t > outputShape (5 , ShapedType::kDynamicSize );
1073
- Conv2DOp ::Adaptor adaptor (operands.getValues (), attributes);
1073
+ Conv3DOp ::Adaptor adaptor (operands.getValues (), attributes);
1074
1074
1075
1075
int32_t inputWidth = ShapedType::kDynamicSize ;
1076
1076
int32_t inputHeight = ShapedType::kDynamicSize ;
@@ -1084,55 +1084,54 @@ LogicalResult Conv3DOp::inferReturnTypeComponents(
1084
1084
ShapeAdaptor inputShape = operands.getShape (adaptor.getInput ());
1085
1085
if (inputShape.hasRank ()) {
1086
1086
outputShape[0 ] = inputShape.getDimSize (0 );
1087
- inputHeight = inputShape.getDimSize (1 );
1088
- inputWidth = inputShape.getDimSize (2 );
1089
- inputDepth = inputShape.getDimSize (3 );
1087
+ inputDepth = inputShape.getDimSize (1 );
1088
+ inputHeight = inputShape.getDimSize (2 );
1089
+ inputWidth = inputShape.getDimSize (3 );
1090
1090
}
1091
1091
1092
1092
// Weight shapes describes the filter width/height and the output channels.
1093
1093
ShapeAdaptor weightShape = operands.getShape (adaptor.getWeight ());
1094
1094
if (weightShape.hasRank ()) {
1095
1095
outputShape[4 ] = weightShape.getDimSize (0 );
1096
- weightHeight = weightShape.getDimSize (1 );
1097
- weightWidth = weightShape.getDimSize (2 );
1098
- weightDepth = weightShape.getDimSize (3 );
1096
+ weightDepth = weightShape.getDimSize (1 );
1097
+ weightHeight = weightShape.getDimSize (2 );
1098
+ weightWidth = weightShape.getDimSize (3 );
1099
1099
}
1100
1100
1101
1101
// Bias shape can describe the output channels.
1102
1102
ShapeAdaptor biasShape = operands.getShape (adaptor.getBias ());
1103
- if (biasShape.hasRank ()) {
1104
- outputShape[4 ] =
1105
- (outputShape[4 ] == -1 ) ? biasShape.getDimSize (0 ) : outputShape[4 ];
1103
+ if (biasShape.hasRank () && ShapedType::isDynamic (outputShape[4 ])) {
1104
+ outputShape[4 ] = biasShape.getDimSize (0 );
1106
1105
}
1107
1106
1108
1107
llvm::SmallVector<int64_t > dilation;
1109
- llvm::SmallVector<int64_t > padding ;
1108
+ llvm::SmallVector<int64_t > pad ;
1110
1109
llvm::SmallVector<int64_t > stride;
1111
1110
1112
1111
getI64Values (adaptor.getDilation (), dilation);
1113
- getI64Values (adaptor.getPad (), padding );
1112
+ getI64Values (adaptor.getPad (), pad );
1114
1113
getI64Values (adaptor.getStride (), stride);
1115
1114
1116
- if (!ShapedType::isDynamic (inputHeight ) &&
1117
- !ShapedType::isDynamic (weightHeight )) {
1118
- int32_t inputSize = inputHeight + padding [0 ] + padding [1 ];
1119
- int32_t filterSize = (weightHeight - 1 ) * dilation[0 ] + 1 ;
1115
+ if (!ShapedType::isDynamic (inputDepth ) &&
1116
+ !ShapedType::isDynamic (weightDepth )) {
1117
+ int32_t inputSize = inputDepth + pad [0 ] + pad [1 ];
1118
+ int32_t filterSize = (weightDepth - 1 ) * dilation[0 ] + 1 ;
1120
1119
int32_t unstridedResult = inputSize - filterSize + 1 ;
1121
1120
outputShape[1 ] = (unstridedResult - 1 ) / stride[0 ] + 1 ;
1122
1121
}
1123
1122
1124
- if (!ShapedType::isDynamic (inputWidth ) &&
1125
- !ShapedType::isDynamic (weightWidth )) {
1126
- int32_t inputSize = inputWidth + padding [2 ] + padding [3 ];
1127
- int32_t filterSize = (weightWidth - 1 ) * dilation[1 ] + 1 ;
1123
+ if (!ShapedType::isDynamic (inputHeight ) &&
1124
+ !ShapedType::isDynamic (weightHeight )) {
1125
+ int32_t inputSize = inputHeight + pad [2 ] + pad [3 ];
1126
+ int32_t filterSize = (weightHeight - 1 ) * dilation[1 ] + 1 ;
1128
1127
int32_t unstridedResult = inputSize - filterSize + 1 ;
1129
1128
outputShape[2 ] = (unstridedResult - 1 ) / stride[1 ] + 1 ;
1130
1129
}
1131
1130
1132
- if (!ShapedType::isDynamic (inputDepth ) &&
1133
- !ShapedType::isDynamic (weightDepth )) {
1134
- int32_t inputSize = inputDepth + padding [4 ] + padding [5 ];
1135
- int32_t filterSize = (weightDepth - 1 ) * dilation[2 ] + 1 ;
1131
+ if (!ShapedType::isDynamic (inputWidth ) &&
1132
+ !ShapedType::isDynamic (weightWidth )) {
1133
+ int32_t inputSize = inputWidth + pad [4 ] + pad [5 ];
1134
+ int32_t filterSize = (weightWidth - 1 ) * dilation[2 ] + 1 ;
1136
1135
int32_t unstridedResult = inputSize - filterSize + 1 ;
1137
1136
outputShape[3 ] = (unstridedResult - 1 ) / stride[2 ] + 1 ;
1138
1137
}
0 commit comments