@@ -34,7 +34,7 @@ Value createConvInGroups(PatternRewriter &rewriter, Operation *op,
3434 const llvm::ArrayRef<int64_t > weightShape, Value &newInput,
3535 Value &newWeight, Value &bias, const int64_t groups,
3636 DenseI64ArrayAttr &pads, DenseI64ArrayAttr &strides,
37- DenseI64ArrayAttr &dilations) {
37+ DenseI64ArrayAttr &dilations, TypeAttr &accType ) {
3838 // Set up constants outside of loop
3939 const int64_t sizeOfSliceInput = weightShape[1 ];
4040 const int64_t sizeOfSliceKernel = weightShape[0 ] / groups;
@@ -65,7 +65,7 @@ Value createConvInGroups(PatternRewriter &rewriter, Operation *op,
6565 mlir::cast<ShapedType>(resultType).getElementType ());
6666 Value tempConv2D = tosa::CreateOpAndInfer<mlir::tosa::Conv2DOp>(rewriter,
6767 op->getLoc (), newConvOutputType, newSliceInput, newSliceWeight,
68- newSliceBias, pads, strides, dilations);
68+ newSliceBias, pads, strides, dilations, accType );
6969 // Add value to vector
7070 sliceValues.push_back (tempConv2D);
7171 }
@@ -156,6 +156,10 @@ class ONNXConvOpLoweringToTOSA : public ConversionPattern {
156156
157157 DenseI64ArrayAttr newPads = rewriter.getDenseI64ArrayAttr (reorderedPads);
158158
159+ Type convType =
160+ (resultType.isF16 ()) ? rewriter.getF16Type () : rewriter.getF32Type ();
161+ TypeAttr accType = mlir::TypeAttr::get (convType);
162+
159163 // Handle group parameter by creating multiple convs
160164 const int64_t group = adaptor.getGroup ();
161165 Value conv2D = NULL ;
@@ -166,10 +170,10 @@ class ONNXConvOpLoweringToTOSA : public ConversionPattern {
166170
167171 conv2D = tosa::CreateOpAndInfer<mlir::tosa::Conv2DOp>(rewriter,
168172 convOp->getLoc (), newConvOutputType, newInput, newWeight, bias,
169- newPads, strides, dilations);
173+ newPads, strides, dilations, accType );
170174 } else {
171175 auto inputChannels = inputType.getDimSize (1 );
172- auto outputChannels = resultType. cast <ShapedType>().getDimSize (1 );
176+ auto outputChannels = cast<ShapedType>(resultType ).getDimSize (1 );
173177 if (group == inputChannels && (outputChannels % inputChannels == 0 )) {
174178 // If the group == inputChannels and
175179 // outputChannels == inputChannels * integerNumber,
@@ -185,19 +189,19 @@ class ONNXConvOpLoweringToTOSA : public ConversionPattern {
185189
186190 Type newConvOutputType = RankedTensorType::get (
187191 llvm::SmallVector<int64_t , 4 >(4 , ShapedType::kDynamic ),
188- resultType. cast <ShapedType>().getElementType ());
192+ cast<ShapedType>(resultType ).getElementType ());
189193
190194 conv2D = tosa::CreateOpAndInfer<mlir::tosa::DepthwiseConv2DOp>(rewriter,
191195 convOp->getLoc (), newConvOutputType, newInput, newWeight, bias,
192- newPads, strides, dilations);
196+ newPads, strides, dilations, accType );
193197 } else if (group <= groupedConvThreshold) {
194198 // Decompose group convolution into a concatenation of tosa.conv2d ops
195199 // can be costly, so only allow it when the number of groups is less
196200 // than configurable threshold.
197201
198202 conv2D = createConvInGroups (rewriter, convOp, tosaBuilder, resultType,
199203 weightShape, newInput, newWeight, bias, group, newPads, strides,
200- dilations);
204+ dilations, accType );
201205 } else {
202206 return rewriter.notifyMatchFailure (
203207 op, " this type of grouped Conv is not supported" );
0 commit comments