@@ -271,6 +271,55 @@ LogicalResult tosa::ConstOp::verify() {
271271 return success ();
272272}
273273
274+ template <typename T>
275+ static LogicalResult verifyConvOpModes (T op) {
276+ auto inputEType =
277+ llvm::cast<ShapedType>(op.getInput ().getType ()).getElementType ();
278+
279+ if (auto quantType =
280+ llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType))
281+ inputEType = quantType.getStorageType ();
282+
283+ auto accType = op.getAccType ();
284+ if (inputEType.isInteger (8 ) && !accType.isInteger (32 ))
285+ return op.emitOpError (" accumulator type for i8 tensor is not i32" );
286+
287+ if (inputEType.isInteger (16 ) && !accType.isInteger (48 ))
288+ return op.emitOpError (" accumulator type for i16 tensor is not i48" );
289+
290+ if ((inputEType.isFloat8E5M2 () || inputEType.isFloat8E4M3FN ()) &&
291+ !accType.isF16 ())
292+ return op.emitOpError (" accumulator type for f8 tensor is not f16" );
293+
294+ if (inputEType.isF16 () && !(accType.isF16 () || accType.isF32 ()))
295+ return op.emitOpError (" accumulator type for f16 tensor is not f16/f32" );
296+
297+ if (inputEType.isBF16 () && !accType.isF32 ())
298+ return op.emitOpError (" accumulator type for bf16 tensor is not f32" );
299+
300+ if (inputEType.isF32 () && !accType.isF32 ())
301+ return op.emitOpError (" accumulator type for f32 tensor is not f32" );
302+
303+ auto resultEType =
304+ llvm::cast<ShapedType>(op.getResult ().getType ()).getElementType ();
305+
306+ if (auto quantType =
307+ llvm::dyn_cast<mlir::quant::UniformQuantizedType>(resultEType))
308+ resultEType = quantType.getStorageType ();
309+
310+ // check allowed input/result element types combinations
311+ if ((inputEType.isInteger (8 ) && resultEType.isInteger (32 )) ||
312+ (inputEType.isInteger (16 ) && resultEType.isInteger (48 )) ||
313+ (inputEType.isFloat8E5M2 () && resultEType.isF16 ()) ||
314+ (inputEType.isFloat8E4M3FN () && resultEType.isF16 ()) ||
315+ (inputEType.isF16 () && resultEType.isF16 ()) ||
316+ (inputEType.isBF16 () && resultEType.isBF16 ()) ||
317+ (inputEType.isF32 () && resultEType.isF32 ()))
318+ return success ();
319+
320+ return op.emitOpError (" input/output element types are incompatible." );
321+ }
322+
274323LogicalResult tosa::ArgMaxOp::verify () {
275324 // Ensure output is of 32-bit integer
276325 const auto resultETy = llvm::cast<ShapedType>(getType ()).getElementType ();
@@ -368,12 +417,14 @@ static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
368417 Type outputType, Value input, Value weight,
369418 Value bias, DenseI64ArrayAttr pad,
370419 DenseI64ArrayAttr stride,
371- DenseI64ArrayAttr dilation) {
420+ DenseI64ArrayAttr dilation,
421+ TypeAttr accType) {
372422
373423 result.addOperands ({input, weight, bias});
374424 result.addAttribute (" pad" , pad);
375425 result.addAttribute (" stride" , stride);
376426 result.addAttribute (" dilation" , dilation);
427+ result.addAttribute (" acc_type" , accType);
377428
378429 auto quantAttr = buildConvOpQuantizationAttr (builder, input, weight);
379430 if (quantAttr) {
@@ -390,11 +441,12 @@ static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
390441static void buildTransConvOpWithQuantInfo (
391442 OpBuilder &builder, OperationState &result, Type outputType, Value input,
392443 Value weight, Value bias, DenseI64ArrayAttr outpad,
393- DenseI64ArrayAttr stride, DenseI64ArrayAttr outputShape) {
444+ DenseI64ArrayAttr stride, DenseI64ArrayAttr outputShape, TypeAttr accType ) {
394445 result.addOperands ({input, weight, bias});
395446 result.addAttribute (" out_pad" , outpad);
396447 result.addAttribute (" stride" , stride);
397448 result.addAttribute (" out_shape" , outputShape);
449+ result.addAttribute (" acc_type" , accType);
398450 auto quantAttr = ::buildConvOpQuantizationAttr (builder, input, weight);
399451
400452 if (quantAttr) {
@@ -1595,7 +1647,11 @@ LogicalResult Conv2DOp::inferReturnTypeComponents(
15951647 return success ();
15961648}
15971649
1598- LogicalResult Conv2DOp::verify () { return verifyConvOp (*this ); }
1650+ LogicalResult Conv2DOp::verify () {
1651+ if (verifyConvOp (*this ).failed () || verifyConvOpModes (*this ).failed ())
1652+ return failure ();
1653+ return success ();
1654+ }
15991655
16001656LogicalResult Conv3DOp::inferReturnTypeComponents (
16011657 MLIRContext *context, ::std::optional<Location> location,
@@ -1667,7 +1723,11 @@ LogicalResult Conv3DOp::inferReturnTypeComponents(
16671723 return success ();
16681724}
16691725
1670- LogicalResult Conv3DOp::verify () { return verifyConvOp (*this ); }
1726+ LogicalResult Conv3DOp::verify () {
1727+ if (verifyConvOp (*this ).failed () || verifyConvOpModes (*this ).failed ())
1728+ return failure ();
1729+ return success ();
1730+ }
16711731
16721732LogicalResult AvgPool2dOp::inferReturnTypeComponents (
16731733 MLIRContext *context, ::std::optional<Location> location,
@@ -1762,7 +1822,11 @@ LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
17621822 return success ();
17631823}
17641824
1765- LogicalResult DepthwiseConv2DOp::verify () { return verifyConvOp (*this ); }
1825+ LogicalResult DepthwiseConv2DOp::verify () {
1826+ if (verifyConvOp (*this ).failed () || verifyConvOpModes (*this ).failed ())
1827+ return failure ();
1828+ return success ();
1829+ }
17661830
17671831LogicalResult TransposeConv2DOp::inferReturnTypeComponents (
17681832 MLIRContext *context, ::std::optional<Location> location,
0 commit comments