@@ -210,15 +210,26 @@ template <typename T>
210210static LogicalResult verifyConvOp (T op) {
211211 // All TOSA conv ops have an input() and weight().
212212 auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput ().getType ());
213- auto weightType = llvm::dyn_cast<RankedTensorType>(op.getWeight ().getType ());
213+
214+ RankedTensorType weightType;
215+ if constexpr (std::is_same_v<T, tosa::TransposeConv2DOp>)
216+ weightType = llvm::dyn_cast<RankedTensorType>(op.getFilter ().getType ());
217+ else
218+ weightType = llvm::dyn_cast<RankedTensorType>(op.getWeight ().getType ());
214219
215220 // Must be ranked tensor types
216221 if (!inputType) {
217222 op.emitOpError (" expect a ranked tensor for input, got " ) << op.getInput ();
218223 return failure ();
219224 }
220225 if (!weightType) {
221- op.emitOpError (" expect a ranked tensor for weight, got " ) << op.getWeight ();
226+ if constexpr (std::is_same_v<T, tosa::TransposeConv2DOp>) {
227+ op.emitOpError (" expect a ranked tensor for filter, got " )
228+ << op.getFilter ();
229+ } else {
230+ op.emitOpError (" expect a ranked tensor for weight, got " )
231+ << op.getWeight ();
232+ }
222233 return failure ();
223234 }
224235
@@ -271,6 +282,38 @@ LogicalResult tosa::ConstOp::verify() {
271282 return success ();
272283}
273284
285+ template <typename T>
286+ static LogicalResult verifyConvOpModes (T op) {
287+ auto inputEType =
288+ llvm::cast<ShapedType>(op.getInput ().getType ()).getElementType ();
289+
290+ if (auto quantType =
291+ llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType))
292+ inputEType = quantType.getStorageType ();
293+
294+ auto accType = op.getAccType ();
295+ if (inputEType.isInteger (8 ) && !accType.isInteger (32 ))
296+ return op.emitOpError (" accumulator type for i8 tensor is not i32" );
297+
298+ if (inputEType.isInteger (16 ) && !accType.isInteger (48 ))
299+ return op.emitOpError (" accumulator type for i16 tensor is not i48" );
300+
301+ if ((inputEType.isFloat8E5M2 () || inputEType.isFloat8E4M3 ()) &&
302+ !accType.isF16 ())
303+ return op.emitOpError (" accumulator type for f8 tensor is not f16" );
304+
305+ if (inputEType.isF16 () && !(accType.isF16 () || accType.isF32 ()))
306+ return op.emitOpError (" accumulator type for f16 tensor is not f16/f32" );
307+
308+ if (inputEType.isBF16 () && !accType.isF32 ())
309+ return op.emitOpError (" accumulator type for bf16 tensor is not f32" );
310+
311+ if (inputEType.isF32 () && !accType.isF32 ())
312+ return op.emitOpError (" accumulator type for f32 tensor is not f32" );
313+
314+ return success ();
315+ }
316+
274317LogicalResult tosa::ArgMaxOp::verify () {
275318 // Ensure output is of 32-bit integer
276319 const auto resultETy = llvm::cast<ShapedType>(getType ()).getElementType ();
@@ -368,12 +411,14 @@ static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
368411 Type outputType, Value input, Value weight,
369412 Value bias, DenseI64ArrayAttr pad,
370413 DenseI64ArrayAttr stride,
371- DenseI64ArrayAttr dilation) {
414+ DenseI64ArrayAttr dilation,
415+ TypeAttr accType) {
372416
373417 result.addOperands ({input, weight, bias});
374418 result.addAttribute (" pad" , pad);
375419 result.addAttribute (" stride" , stride);
376420 result.addAttribute (" dilation" , dilation);
421+ result.addAttribute (" acc_type" , accType);
377422
378423 auto quantAttr = buildConvOpQuantizationAttr (builder, input, weight);
379424 if (quantAttr) {
@@ -390,11 +435,12 @@ static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
390435static void buildTransConvOpWithQuantInfo (
391436 OpBuilder &builder, OperationState &result, Type outputType, Value input,
392437 Value weight, Value bias, DenseI64ArrayAttr outpad,
393- DenseI64ArrayAttr stride, DenseI64ArrayAttr outputShape) {
438+ DenseI64ArrayAttr stride, DenseI64ArrayAttr outputShape, TypeAttr accType ) {
394439 result.addOperands ({input, weight, bias});
395440 result.addAttribute (" out_pad" , outpad);
396441 result.addAttribute (" stride" , stride);
397442 result.addAttribute (" out_shape" , outputShape);
443+ result.addAttribute (" acc_type" , accType);
398444 auto quantAttr = ::buildConvOpQuantizationAttr (builder, input, weight);
399445
400446 if (quantAttr) {
@@ -1678,7 +1724,11 @@ LogicalResult Conv2DOp::inferReturnTypeComponents(
16781724 return success ();
16791725}
16801726
1681- LogicalResult Conv2DOp::verify () { return verifyConvOp (*this ); }
1727+ LogicalResult Conv2DOp::verify () {
1728+ if (verifyConvOp (*this ).failed () || verifyConvOpModes (*this ).failed ())
1729+ return failure ();
1730+ return success ();
1731+ }
16821732
16831733LogicalResult Conv3DOp::inferReturnTypeComponents (
16841734 MLIRContext *context, ::std::optional<Location> location,
@@ -1750,7 +1800,11 @@ LogicalResult Conv3DOp::inferReturnTypeComponents(
17501800 return success ();
17511801}
17521802
1753- LogicalResult Conv3DOp::verify () { return verifyConvOp (*this ); }
1803+ LogicalResult Conv3DOp::verify () {
1804+ if (verifyConvOp (*this ).failed () || verifyConvOpModes (*this ).failed ())
1805+ return failure ();
1806+ return success ();
1807+ }
17541808
17551809LogicalResult AvgPool2dOp::inferReturnTypeComponents (
17561810 MLIRContext *context, ::std::optional<Location> location,
@@ -1845,7 +1899,11 @@ LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
18451899 return success ();
18461900}
18471901
1848- LogicalResult DepthwiseConv2DOp::verify () { return verifyConvOp (*this ); }
1902+ LogicalResult DepthwiseConv2DOp::verify () {
1903+ if (verifyConvOp (*this ).failed () || verifyConvOpModes (*this ).failed ())
1904+ return failure ();
1905+ return success ();
1906+ }
18491907
18501908LogicalResult TransposeConv2DOp::inferReturnTypeComponents (
18511909 MLIRContext *context, ::std::optional<Location> location,
@@ -1911,6 +1969,12 @@ LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
19111969 return success ();
19121970}
19131971
1972+ LogicalResult TransposeConv2DOp::verify () {
1973+ if (verifyConvOp (*this ).failed () || verifyConvOpModes (*this ).failed ())
1974+ return failure ();
1975+ return success ();
1976+ }
1977+
19141978LogicalResult IfOp::inferReturnTypeComponents (
19151979 MLIRContext *context, ::std::optional<Location> location,
19161980 IfOp::Adaptor adaptor,
0 commit comments