@@ -210,15 +210,26 @@ template <typename T>
210210static LogicalResult verifyConvOp (T op) {
211211 // All TOSA conv ops have an input() and weight().
212212 auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput ().getType ());
213- auto weightType = llvm::dyn_cast<RankedTensorType>(op.getWeight ().getType ());
213+
214+ RankedTensorType weightType;
215+ if constexpr (std::is_same_v<T, tosa::TransposeConv2DOp>)
216+ weightType = llvm::dyn_cast<RankedTensorType>(op.getFilter ().getType ());
217+ else
218+ weightType = llvm::dyn_cast<RankedTensorType>(op.getWeight ().getType ());
214219
215220 // Must be ranked tensor types
216221 if (!inputType) {
217222 op.emitOpError (" expect a ranked tensor for input, got " ) << op.getInput ();
218223 return failure ();
219224 }
220225 if (!weightType) {
221- op.emitOpError (" expect a ranked tensor for weight, got " ) << op.getWeight ();
226+ if constexpr (std::is_same_v<T, tosa::TransposeConv2DOp>) {
227+ op.emitOpError (" expect a ranked tensor for filter, got " )
228+ << op.getFilter ();
229+ } else {
230+ op.emitOpError (" expect a ranked tensor for weight, got " )
231+ << op.getWeight ();
232+ }
222233 return failure ();
223234 }
224235
@@ -271,6 +282,38 @@ LogicalResult tosa::ConstOp::verify() {
271282 return success ();
272283}
273284
285+ template <typename T>
286+ static LogicalResult verifyConvOpModes (T op) {
287+ auto inputEType =
288+ llvm::cast<ShapedType>(op.getInput ().getType ()).getElementType ();
289+
290+ if (auto quantType =
291+ llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType))
292+ inputEType = quantType.getStorageType ();
293+
294+ auto accType = op.getAccType ();
295+ if (inputEType.isInteger (8 ) && !accType.isInteger (32 ))
296+ return op.emitOpError (" accumulator type for i8 tensor is not i32" );
297+
298+ if (inputEType.isInteger (16 ) && !accType.isInteger (48 ))
299+ return op.emitOpError (" accumulator type for i16 tensor is not i48" );
300+
301+ if ((inputEType.isFloat8E5M2 () || inputEType.isFloat8E4M3 ()) &&
302+ !accType.isF16 ())
303+ return op.emitOpError (" accumulator type for f8 tensor is not f16" );
304+
305+ if (inputEType.isF16 () && !(accType.isF16 () || accType.isF32 ()))
306+ return op.emitOpError (" accumulator type for f16 tensor is not f16/f32" );
307+
308+ if (inputEType.isBF16 () && !accType.isF32 ())
309+ return op.emitOpError (" accumulator type for bf16 tensor is not f32" );
310+
311+ if (inputEType.isF32 () && !accType.isF32 ())
312+ return op.emitOpError (" accumulator type for f32 tensor is not f32" );
313+
314+ return success ();
315+ }
316+
274317LogicalResult tosa::ArgMaxOp::verify () {
275318 // Ensure output is of 32-bit integer
276319 const auto resultETy = llvm::cast<ShapedType>(getType ()).getElementType ();
@@ -368,12 +411,14 @@ static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
368411 Type outputType, Value input, Value weight,
369412 Value bias, DenseI64ArrayAttr pad,
370413 DenseI64ArrayAttr stride,
371- DenseI64ArrayAttr dilation) {
414+ DenseI64ArrayAttr dilation,
415+ TypeAttr accType) {
372416
373417 result.addOperands ({input, weight, bias});
374418 result.addAttribute (" pad" , pad);
375419 result.addAttribute (" stride" , stride);
376420 result.addAttribute (" dilation" , dilation);
421+ result.addAttribute (" acc_type" , accType);
377422
378423 auto quantAttr = buildConvOpQuantizationAttr (builder, input, weight);
379424 if (quantAttr) {
@@ -390,11 +435,12 @@ static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
390435static void buildTransConvOpWithQuantInfo (
391436 OpBuilder &builder, OperationState &result, Type outputType, Value input,
392437 Value weight, Value bias, DenseI64ArrayAttr outpad,
393- DenseI64ArrayAttr stride, DenseI64ArrayAttr outputShape) {
438+ DenseI64ArrayAttr stride, DenseI64ArrayAttr outputShape, TypeAttr accType ) {
394439 result.addOperands ({input, weight, bias});
395440 result.addAttribute (" out_pad" , outpad);
396441 result.addAttribute (" stride" , stride);
397442 result.addAttribute (" out_shape" , outputShape);
443+ result.addAttribute (" acc_type" , accType);
398444 auto quantAttr = ::buildConvOpQuantizationAttr (builder, input, weight);
399445
400446 if (quantAttr) {
@@ -1599,7 +1645,11 @@ LogicalResult Conv2DOp::inferReturnTypeComponents(
15991645 return success ();
16001646}
16011647
1602- LogicalResult Conv2DOp::verify () { return verifyConvOp (*this ); }
1648+ LogicalResult Conv2DOp::verify () {
1649+ if (verifyConvOp (*this ).failed () || verifyConvOpModes (*this ).failed ())
1650+ return failure ();
1651+ return success ();
1652+ }
16031653
16041654LogicalResult Conv3DOp::inferReturnTypeComponents (
16051655 MLIRContext *context, ::std::optional<Location> location,
@@ -1671,7 +1721,11 @@ LogicalResult Conv3DOp::inferReturnTypeComponents(
16711721 return success ();
16721722}
16731723
1674- LogicalResult Conv3DOp::verify () { return verifyConvOp (*this ); }
1724+ LogicalResult Conv3DOp::verify () {
1725+ if (verifyConvOp (*this ).failed () || verifyConvOpModes (*this ).failed ())
1726+ return failure ();
1727+ return success ();
1728+ }
16751729
16761730LogicalResult AvgPool2dOp::inferReturnTypeComponents (
16771731 MLIRContext *context, ::std::optional<Location> location,
@@ -1766,7 +1820,11 @@ LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
17661820 return success ();
17671821}
17681822
1769- LogicalResult DepthwiseConv2DOp::verify () { return verifyConvOp (*this ); }
1823+ LogicalResult DepthwiseConv2DOp::verify () {
1824+ if (verifyConvOp (*this ).failed () || verifyConvOpModes (*this ).failed ())
1825+ return failure ();
1826+ return success ();
1827+ }
17701828
17711829LogicalResult TransposeConv2DOp::inferReturnTypeComponents (
17721830 MLIRContext *context, ::std::optional<Location> location,
@@ -1832,6 +1890,12 @@ LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
18321890 return success ();
18331891}
18341892
1893+ LogicalResult TransposeConv2DOp::verify () {
1894+ if (verifyConvOp (*this ).failed () || verifyConvOpModes (*this ).failed ())
1895+ return failure ();
1896+ return success ();
1897+ }
1898+
18351899LogicalResult IfOp::inferReturnTypeComponents (
18361900 MLIRContext *context, ::std::optional<Location> location,
18371901 IfOp::Adaptor adaptor,
0 commit comments