@@ -210,15 +210,26 @@ template <typename T>
210210static LogicalResult verifyConvOp (T op) {
211211 // All TOSA conv ops have an input() and weight().
212212 auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput ().getType ());
213- auto weightType = llvm::dyn_cast<RankedTensorType>(op.getWeight ().getType ());
213+
214+ RankedTensorType weightType;
215+ if constexpr (std::is_same_v<T, tosa::TransposeConv2DOp>)
216+ weightType = llvm::dyn_cast<RankedTensorType>(op.getFilter ().getType ());
217+ else
218+ weightType = llvm::dyn_cast<RankedTensorType>(op.getWeight ().getType ());
214219
215220 // Must be ranked tensor types
216221 if (!inputType) {
217222 op.emitOpError (" expect a ranked tensor for input, got " ) << op.getInput ();
218223 return failure ();
219224 }
220225 if (!weightType) {
221- op.emitOpError (" expect a ranked tensor for weight, got " ) << op.getWeight ();
226+ if constexpr (std::is_same_v<T, tosa::TransposeConv2DOp>) {
227+ op.emitOpError (" expect a ranked tensor for filter, got " )
228+ << op.getFilter ();
229+ } else {
230+ op.emitOpError (" expect a ranked tensor for weight, got " )
231+ << op.getWeight ();
232+ }
222233 return failure ();
223234 }
224235
@@ -271,6 +282,38 @@ LogicalResult tosa::ConstOp::verify() {
271282 return success ();
272283}
273284
285+ template <typename T>
286+ static LogicalResult verifyConvOpModes (T op) {
287+ auto inputEType =
288+ llvm::cast<ShapedType>(op.getInput ().getType ()).getElementType ();
289+
290+ if (auto quantType =
291+ llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType))
292+ inputEType = quantType.getStorageType ();
293+
294+ auto accType = op.getAccType ();
295+ if (inputEType.isInteger (8 ) && !accType.isInteger (32 ))
296+ return op.emitOpError (" accumulator type for i8 tensor is not i32" );
297+
298+ if (inputEType.isInteger (16 ) && !accType.isInteger (48 ))
299+ return op.emitOpError (" accumulator type for i16 tensor is not i48" );
300+
301+ if ((inputEType.isFloat8E5M2 () || inputEType.isFloat8E4M3 ()) &&
302+ !accType.isF16 ())
303+ return op.emitOpError (" accumulator type for f8 tensor is not f16" );
304+
305+ if (inputEType.isF16 () && !(accType.isF16 () || accType.isF32 ()))
306+ return op.emitOpError (" accumulator type for f16 tensor is not f16/f32" );
307+
308+ if (inputEType.isBF16 () && !accType.isF32 ())
309+ return op.emitOpError (" accumulator type for bf16 tensor is not f32" );
310+
311+ if (inputEType.isF32 () && !accType.isF32 ())
312+ return op.emitOpError (" accumulator type for f32 tensor is not f32" );
313+
314+ return success ();
315+ }
316+
274317LogicalResult tosa::ArgMaxOp::verify () {
275318 // Ensure output is of 32-bit integer
276319 const auto resultETy = llvm::cast<ShapedType>(getType ()).getElementType ();
@@ -368,12 +411,14 @@ static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
368411 Type outputType, Value input, Value weight,
369412 Value bias, DenseI64ArrayAttr pad,
370413 DenseI64ArrayAttr stride,
371- DenseI64ArrayAttr dilation) {
414+ DenseI64ArrayAttr dilation,
415+ TypeAttr accType) {
372416
373417 result.addOperands ({input, weight, bias});
374418 result.addAttribute (" pad" , pad);
375419 result.addAttribute (" stride" , stride);
376420 result.addAttribute (" dilation" , dilation);
421+ result.addAttribute (" acc_type" , accType);
377422
378423 auto quantAttr = buildConvOpQuantizationAttr (builder, input, weight);
379424 if (quantAttr) {
@@ -390,11 +435,12 @@ static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
390435static void buildTransConvOpWithQuantInfo (
391436 OpBuilder &builder, OperationState &result, Type outputType, Value input,
392437 Value weight, Value bias, DenseI64ArrayAttr outpad,
393- DenseI64ArrayAttr stride, DenseI64ArrayAttr outputShape) {
438+ DenseI64ArrayAttr stride, DenseI64ArrayAttr outputShape, TypeAttr accType ) {
394439 result.addOperands ({input, weight, bias});
395440 result.addAttribute (" out_pad" , outpad);
396441 result.addAttribute (" stride" , stride);
397442 result.addAttribute (" out_shape" , outputShape);
443+ result.addAttribute (" acc_type" , accType);
398444 auto quantAttr = ::buildConvOpQuantizationAttr (builder, input, weight);
399445
400446 if (quantAttr) {
@@ -1595,7 +1641,11 @@ LogicalResult Conv2DOp::inferReturnTypeComponents(
15951641 return success ();
15961642}
15971643
1598- LogicalResult Conv2DOp::verify () { return verifyConvOp (*this ); }
1644+ LogicalResult Conv2DOp::verify () {
1645+ if (verifyConvOp (*this ).failed () || verifyConvOpModes (*this ).failed ())
1646+ return failure ();
1647+ return success ();
1648+ }
15991649
16001650LogicalResult Conv3DOp::inferReturnTypeComponents (
16011651 MLIRContext *context, ::std::optional<Location> location,
@@ -1667,7 +1717,11 @@ LogicalResult Conv3DOp::inferReturnTypeComponents(
16671717 return success ();
16681718}
16691719
1670- LogicalResult Conv3DOp::verify () { return verifyConvOp (*this ); }
1720+ LogicalResult Conv3DOp::verify () {
1721+ if (verifyConvOp (*this ).failed () || verifyConvOpModes (*this ).failed ())
1722+ return failure ();
1723+ return success ();
1724+ }
16711725
16721726LogicalResult AvgPool2dOp::inferReturnTypeComponents (
16731727 MLIRContext *context, ::std::optional<Location> location,
@@ -1762,7 +1816,11 @@ LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
17621816 return success ();
17631817}
17641818
1765- LogicalResult DepthwiseConv2DOp::verify () { return verifyConvOp (*this ); }
1819+ LogicalResult DepthwiseConv2DOp::verify () {
1820+ if (verifyConvOp (*this ).failed () || verifyConvOpModes (*this ).failed ())
1821+ return failure ();
1822+ return success ();
1823+ }
17661824
17671825LogicalResult TransposeConv2DOp::inferReturnTypeComponents (
17681826 MLIRContext *context, ::std::optional<Location> location,
@@ -1828,6 +1886,12 @@ LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
18281886 return success ();
18291887}
18301888
1889+ LogicalResult TransposeConv2DOp::verify () {
1890+ if (verifyConvOp (*this ).failed () || verifyConvOpModes (*this ).failed ())
1891+ return failure ();
1892+ return success ();
1893+ }
1894+
18311895LogicalResult IfOp::inferReturnTypeComponents (
18321896 MLIRContext *context, ::std::optional<Location> location,
18331897 IfOp::Adaptor adaptor,
0 commit comments