@@ -278,19 +278,8 @@ Value mlir::tosa::createPadConstTensor(OpBuilder &builder, Location loc,
278278
279279template <typename T>
280280static LogicalResult verifyConvOp (T op) {
281- // All TOSA conv ops have an input and weight arguments which must be ranked
282- // tensors.
283- auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput ().getType ());
284- if (!inputType) {
285- op.emitOpError (" expect a ranked tensor for input, got " ) << op.getInput ();
286- return failure ();
287- }
288-
289- auto weightType = llvm::dyn_cast<RankedTensorType>(op.getWeight ().getType ());
290- if (!weightType) {
291- op.emitOpError (" expect a ranked tensor for weight, got " ) << op.getWeight ();
292- return failure ();
293- }
281+ const auto inputType = llvm::dyn_cast<TensorType>(op.getInput ().getType ());
282+ const auto weightType = llvm::dyn_cast<TensorType>(op.getWeight ().getType ());
294283
295284 auto inputEType = inputType.getElementType ();
296285 auto weightEType = weightType.getElementType ();
@@ -3063,14 +3052,6 @@ LogicalResult TransposeConv2DOp::verify() {
30633052 return emitOpError (" expect all stride values to be >= 1, got [" )
30643053 << strides << " ]" ;
30653054
3066- const auto inputType = llvm::dyn_cast<RankedTensorType>(getInput ().getType ());
3067-
3068- const auto outputType =
3069- llvm::dyn_cast<RankedTensorType>(getOutput ().getType ());
3070-
3071- const auto weightType =
3072- llvm::dyn_cast<RankedTensorType>(getWeight ().getType ());
3073-
30743055 const auto checkPadAgainstKernelDim =
30753056 [this ](int64_t pad_value, int64_t kernel_dim_size,
30763057 llvm::StringRef pad_name,
@@ -3084,69 +3065,77 @@ LogicalResult TransposeConv2DOp::verify() {
30843065 };
30853066
30863067 const llvm::ArrayRef<int64_t > padding = getOutPad ();
3087-
30883068 const int64_t outPadTop = padding[0 ];
30893069 const int64_t outPadBottom = padding[1 ];
3070+ const int64_t outPadLeft = padding[2 ];
3071+ const int64_t outPadRight = padding[3 ];
30903072
3091- const int64_t kernelHeight = weightType.getDimSize (1 );
3092-
3093- if (!ShapedType::isDynamic (kernelHeight)) {
3094- if (failed (checkPadAgainstKernelDim (outPadTop, kernelHeight, " out_pad_top" ,
3095- " KH" )))
3096- return failure ();
3097-
3098- if (failed (checkPadAgainstKernelDim (outPadBottom, kernelHeight,
3099- " out_pad_bottom" , " KH" )))
3100- return failure ();
3101- }
3073+ const auto weightType =
3074+ llvm::dyn_cast<RankedTensorType>(getWeight ().getType ());
31023075
3103- const int64_t kernelWidth = weightType.getDimSize (2 );
3076+ if (weightType) {
3077+ const int64_t kernelHeight = weightType.getDimSize (1 );
3078+ if (!ShapedType::isDynamic (kernelHeight)) {
3079+ if (failed (checkPadAgainstKernelDim (outPadTop, kernelHeight,
3080+ " out_pad_top" , " KH" )))
3081+ return failure ();
31043082
3105- const int64_t outPadLeft = padding[2 ];
3106- const int64_t outPadRight = padding[3 ];
3083+ if (failed (checkPadAgainstKernelDim (outPadBottom, kernelHeight,
3084+ " out_pad_bottom" , " KH" )))
3085+ return failure ();
3086+ }
31073087
3108- if (!ShapedType::isDynamic (kernelWidth)) {
3109- if (failed (checkPadAgainstKernelDim (outPadLeft, kernelWidth, " out_pad_left" ,
3110- " KW" )))
3111- return failure ();
3088+ const int64_t kernelWidth = weightType.getDimSize (2 );
3089+ if (!ShapedType::isDynamic (kernelWidth)) {
3090+ if (failed (checkPadAgainstKernelDim (outPadLeft, kernelWidth,
3091+ " out_pad_left" , " KW" )))
3092+ return failure ();
31123093
3113- if (failed (checkPadAgainstKernelDim (outPadRight, kernelWidth,
3114- " out_pad_right" , " KW" )))
3115- return failure ();
3094+ if (failed (checkPadAgainstKernelDim (outPadRight, kernelWidth,
3095+ " out_pad_right" , " KW" )))
3096+ return failure ();
3097+ }
31163098 }
31173099
31183100 // Rest of the checks depend on the output type being a RankedTensorType
3101+ const auto outputType =
3102+ llvm::dyn_cast<RankedTensorType>(getOutput ().getType ());
31193103 if (!outputType)
31203104 return success ();
31213105
3122- const int64_t inputHeight = inputType.getDimSize (1 );
3123- const int64_t outputHeight = outputType.getDimSize (1 );
3124-
3125- if (!ShapedType::isDynamic (inputHeight) &&
3126- !ShapedType::isDynamic (outputHeight)) {
3127- if (outputHeight !=
3128- (inputHeight - 1 ) * strideY + outPadTop + outPadBottom + kernelHeight)
3129- return emitOpError (
3130- " dimension mismatch: expected OH == (IH - 1) * stride_y "
3131- " + out_pad_top + out_pad_bottom + KH, but got " )
3132- << outputHeight << " != (" << inputHeight << " - 1) * " << strideY
3133- << " + " << outPadTop << " + " << outPadBottom << " + "
3134- << kernelHeight;
3135- }
3106+ const auto inputType = llvm::dyn_cast<RankedTensorType>(getInput ().getType ());
3107+ if (inputType && weightType) {
3108+ const int64_t inputHeight = inputType.getDimSize (1 );
3109+ const int64_t kernelHeight = weightType.getDimSize (1 );
3110+ const int64_t outputHeight = outputType.getDimSize (1 );
3111+
3112+ if (!ShapedType::isDynamic (inputHeight) &&
3113+ !ShapedType::isDynamic (outputHeight)) {
3114+ if (outputHeight !=
3115+ (inputHeight - 1 ) * strideY + outPadTop + outPadBottom + kernelHeight)
3116+ return emitOpError (
3117+ " dimension mismatch: expected OH == (IH - 1) * stride_y "
3118+ " + out_pad_top + out_pad_bottom + KH, but got " )
3119+ << outputHeight << " != (" << inputHeight << " - 1) * "
3120+ << strideY << " + " << outPadTop << " + " << outPadBottom
3121+ << " + " << kernelHeight;
3122+ }
31363123
3137- const int64_t inputWidth = inputType.getDimSize (2 );
3138- const int64_t outputWidth = outputType.getDimSize (2 );
3124+ const int64_t inputWidth = inputType.getDimSize (2 );
3125+ const int64_t kernelWidth = weightType.getDimSize (2 );
3126+ const int64_t outputWidth = outputType.getDimSize (2 );
31393127
3140- if (!ShapedType::isDynamic (inputWidth) &&
3141- !ShapedType::isDynamic (outputWidth)) {
3142- if (outputWidth !=
3143- (inputWidth - 1 ) * strideX + outPadLeft + outPadRight + kernelWidth)
3144- return emitOpError (
3145- " dimension mismatch: expected OW == (IW - 1) * stride_x "
3146- " + out_pad_left + out_pad_right + KW, but got " )
3147- << outputWidth << " != (" << inputWidth << " - 1) * " << strideX
3148- << " + " << outPadLeft << " + " << outPadRight << " + "
3149- << kernelWidth;
3128+ if (!ShapedType::isDynamic (inputWidth) &&
3129+ !ShapedType::isDynamic (outputWidth)) {
3130+ if (outputWidth !=
3131+ (inputWidth - 1 ) * strideX + outPadLeft + outPadRight + kernelWidth)
3132+ return emitOpError (
3133+ " dimension mismatch: expected OW == (IW - 1) * stride_x "
3134+ " + out_pad_left + out_pad_right + KW, but got " )
3135+ << outputWidth << " != (" << inputWidth << " - 1) * " << strideX
3136+ << " + " << outPadLeft << " + " << outPadRight << " + "
3137+ << kernelWidth;
3138+ }
31503139 }
31513140
31523141 const auto biasType = llvm::dyn_cast<RankedTensorType>(getBias ().getType ());
0 commit comments