@@ -278,19 +278,8 @@ Value mlir::tosa::createPadConstTensor(OpBuilder &builder, Location loc,
278278
279279template <typename T>
280280static LogicalResult verifyConvOp (T op) {
281- // All TOSA conv ops have an input and weight arguments which must be ranked
282- // tensors.
283- auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput ().getType ());
284- if (!inputType) {
285- op.emitOpError (" expect a ranked tensor for input, got " ) << op.getInput ();
286- return failure ();
287- }
288-
289- auto weightType = llvm::dyn_cast<RankedTensorType>(op.getWeight ().getType ());
290- if (!weightType) {
291- op.emitOpError (" expect a ranked tensor for weight, got " ) << op.getWeight ();
292- return failure ();
293- }
281+ const auto inputType = llvm::dyn_cast<TensorType>(op.getInput ().getType ());
282+ const auto weightType = llvm::dyn_cast<TensorType>(op.getWeight ().getType ());
294283
295284 auto inputEType = inputType.getElementType ();
296285 auto weightEType = weightType.getElementType ();
@@ -2998,14 +2987,6 @@ LogicalResult TransposeConv2DOp::verify() {
29982987 return emitOpError (" expect all stride values to be >= 1, got [" )
29992988 << strides << " ]" ;
30002989
3001- const auto inputType = llvm::dyn_cast<RankedTensorType>(getInput ().getType ());
3002-
3003- const auto outputType =
3004- llvm::dyn_cast<RankedTensorType>(getOutput ().getType ());
3005-
3006- const auto weightType =
3007- llvm::dyn_cast<RankedTensorType>(getWeight ().getType ());
3008-
30092990 const auto checkPadAgainstKernelDim =
30102991 [this ](int64_t pad_value, int64_t kernel_dim_size,
30112992 llvm::StringRef pad_name,
@@ -3019,69 +3000,77 @@ LogicalResult TransposeConv2DOp::verify() {
30193000 };
30203001
30213002 const llvm::ArrayRef<int64_t > padding = getOutPad ();
3022-
30233003 const int64_t outPadTop = padding[0 ];
30243004 const int64_t outPadBottom = padding[1 ];
3005+ const int64_t outPadLeft = padding[2 ];
3006+ const int64_t outPadRight = padding[3 ];
30253007
3026- const int64_t kernelHeight = weightType.getDimSize (1 );
3027-
3028- if (!ShapedType::isDynamic (kernelHeight)) {
3029- if (failed (checkPadAgainstKernelDim (outPadTop, kernelHeight, " out_pad_top" ,
3030- " KH" )))
3031- return failure ();
3032-
3033- if (failed (checkPadAgainstKernelDim (outPadBottom, kernelHeight,
3034- " out_pad_bottom" , " KH" )))
3035- return failure ();
3036- }
3008+ const auto weightType =
3009+ llvm::dyn_cast<RankedTensorType>(getWeight ().getType ());
30373010
3038- const int64_t kernelWidth = weightType.getDimSize (2 );
3011+ if (weightType) {
3012+ const int64_t kernelHeight = weightType.getDimSize (1 );
3013+ if (!ShapedType::isDynamic (kernelHeight)) {
3014+ if (failed (checkPadAgainstKernelDim (outPadTop, kernelHeight,
3015+ " out_pad_top" , " KH" )))
3016+ return failure ();
30393017
3040- const int64_t outPadLeft = padding[2 ];
3041- const int64_t outPadRight = padding[3 ];
3018+ if (failed (checkPadAgainstKernelDim (outPadBottom, kernelHeight,
3019+ " out_pad_bottom" , " KH" )))
3020+ return failure ();
3021+ }
30423022
3043- if (!ShapedType::isDynamic (kernelWidth)) {
3044- if (failed (checkPadAgainstKernelDim (outPadLeft, kernelWidth, " out_pad_left" ,
3045- " KW" )))
3046- return failure ();
3023+ const int64_t kernelWidth = weightType.getDimSize (2 );
3024+ if (!ShapedType::isDynamic (kernelWidth)) {
3025+ if (failed (checkPadAgainstKernelDim (outPadLeft, kernelWidth,
3026+ " out_pad_left" , " KW" )))
3027+ return failure ();
30473028
3048- if (failed (checkPadAgainstKernelDim (outPadRight, kernelWidth,
3049- " out_pad_right" , " KW" )))
3050- return failure ();
3029+ if (failed (checkPadAgainstKernelDim (outPadRight, kernelWidth,
3030+ " out_pad_right" , " KW" )))
3031+ return failure ();
3032+ }
30513033 }
30523034
30533035 // Rest of the checks depend on the output type being a RankedTensorType
3036+ const auto outputType =
3037+ llvm::dyn_cast<RankedTensorType>(getOutput ().getType ());
30543038 if (!outputType)
30553039 return success ();
30563040
3057- const int64_t inputHeight = inputType.getDimSize (1 );
3058- const int64_t outputHeight = outputType.getDimSize (1 );
3059-
3060- if (!ShapedType::isDynamic (inputHeight) &&
3061- !ShapedType::isDynamic (outputHeight)) {
3062- if (outputHeight !=
3063- (inputHeight - 1 ) * strideY + outPadTop + outPadBottom + kernelHeight)
3064- return emitOpError (
3065- " dimension mismatch: expected OH == (IH - 1) * stride_y "
3066- " + out_pad_top + out_pad_bottom + KH, but got " )
3067- << outputHeight << " != (" << inputHeight << " - 1) * " << strideY
3068- << " + " << outPadTop << " + " << outPadBottom << " + "
3069- << kernelHeight;
3070- }
3041+ const auto inputType = llvm::dyn_cast<RankedTensorType>(getInput ().getType ());
3042+ if (inputType && weightType) {
3043+ const int64_t inputHeight = inputType.getDimSize (1 );
3044+ const int64_t kernelHeight = weightType.getDimSize (1 );
3045+ const int64_t outputHeight = outputType.getDimSize (1 );
3046+
3047+ if (!ShapedType::isDynamic (inputHeight) &&
3048+ !ShapedType::isDynamic (outputHeight)) {
3049+ if (outputHeight !=
3050+ (inputHeight - 1 ) * strideY + outPadTop + outPadBottom + kernelHeight)
3051+ return emitOpError (
3052+ " dimension mismatch: expected OH == (IH - 1) * stride_y "
3053+ " + out_pad_top + out_pad_bottom + KH, but got " )
3054+ << outputHeight << " != (" << inputHeight << " - 1) * "
3055+ << strideY << " + " << outPadTop << " + " << outPadBottom
3056+ << " + " << kernelHeight;
3057+ }
30713058
3072- const int64_t inputWidth = inputType.getDimSize (2 );
3073- const int64_t outputWidth = outputType.getDimSize (2 );
3059+ const int64_t inputWidth = inputType.getDimSize (2 );
3060+ const int64_t kernelWidth = weightType.getDimSize (2 );
3061+ const int64_t outputWidth = outputType.getDimSize (2 );
30743062
3075- if (!ShapedType::isDynamic (inputWidth) &&
3076- !ShapedType::isDynamic (outputWidth)) {
3077- if (outputWidth !=
3078- (inputWidth - 1 ) * strideX + outPadLeft + outPadRight + kernelWidth)
3079- return emitOpError (
3080- " dimension mismatch: expected OW == (IW - 1) * stride_x "
3081- " + out_pad_left + out_pad_right + KW, but got " )
3082- << outputWidth << " != (" << inputWidth << " - 1) * " << strideX
3083- << " + " << outPadLeft << " + " << outPadRight << " + "
3084- << kernelWidth;
3063+ if (!ShapedType::isDynamic (inputWidth) &&
3064+ !ShapedType::isDynamic (outputWidth)) {
3065+ if (outputWidth !=
3066+ (inputWidth - 1 ) * strideX + outPadLeft + outPadRight + kernelWidth)
3067+ return emitOpError (
3068+ " dimension mismatch: expected OW == (IW - 1) * stride_x "
3069+ " + out_pad_left + out_pad_right + KW, but got " )
3070+ << outputWidth << " != (" << inputWidth << " - 1) * " << strideX
3071+ << " + " << outPadLeft << " + " << outPadRight << " + "
3072+ << kernelWidth;
3073+ }
30853074 }
30863075
30873076 const auto biasType = llvm::dyn_cast<RankedTensorType>(getBias ().getType ());
0 commit comments