@@ -119,10 +119,11 @@ static AffineMap getBroadcastingMap(PatternRewriter &rewriter, Value source,
 }
 
 // Broadcast the source value to all the outer dimensions of the result value.
-// If required, the element type is expanded using an arith.extsi operation.
-static mlir::Value linalgBroadcastAndMaybeExtSI(PatternRewriter &rewriter,
-                                                Location loc, Value source,
-                                                Value result) {
+// If required, the element type is expanded using an arith.extsi or arith.extf
+// operation as appropriate.
+static mlir::Value linalgBroadcastAndMaybeExt(PatternRewriter &rewriter,
+                                              Location loc, Value source,
+                                              Value result) {
   ShapedType resultTy = cast<ShapedType>(result.getType());
   const int64_t resultRank = resultTy.getRank();
   // Creating maps for the input and output of the broadcast-like generic op.
@@ -135,11 +136,16 @@ static mlir::Value linalgBroadcastAndMaybeExtSI(PatternRewriter &rewriter,
           .create<linalg::GenericOp>(
               loc, resultTy, ValueRange({source}), result, indexingMaps,
               getNParallelLoopsAttrs(resultTy.getRank()),
-              [](OpBuilder &builder, Location loc, ValueRange args) {
+              [&resultTy](OpBuilder &builder, Location loc, ValueRange args) {
                 Value biasVal = args[0];
                 Type resType = args[1].getType();
                 if (resType != biasVal.getType()) {
-                  biasVal = builder.create<arith::ExtSIOp>(loc, resType, biasVal);
+                  biasVal =
+                      resultTy.getElementType().isFloat()
+                          ? builder.create<arith::ExtFOp>(loc, resType, biasVal)
+                                .getResult()
+                          : builder.create<arith::ExtSIOp>(loc, resType, biasVal)
+                                .getResult();
                 }
                 builder.create<linalg::YieldOp>(loc, biasVal);
               })
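
For a floating-point bias, the helper now widens inside the broadcasting linalg.generic with arith.extf instead of arith.extsi. A minimal sketch of the IR this produces, assuming an f16 bias broadcast into an f32 accumulator (shapes and SSA names are illustrative, not taken from the commit's tests):

    #bias_map = affine_map<(d0, d1, d2, d3) -> (d3)>
    #out_map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>

    // Broadcast a 16-element f16 bias across an NHWC f32 accumulator,
    // widening each element with arith.extf on the way in.
    %bcast = linalg.generic
        {indexing_maps = [#bias_map, #out_map],
         iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
        ins(%bias : tensor<16xf16>) outs(%acc_init : tensor<1x30x30x16xf32>) {
    ^bb0(%in: f16, %out: f32):
      %ext = arith.extf %in : f16 to f32
      linalg.yield %ext : f32
    } -> tensor<1x30x30x16xf32>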
@@ -253,12 +259,14 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
     ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());
 
     Type inputETy = inputTy.getElementType();
-    Type resultETy = resultTy.getElementType();
 
     DenseI64ArrayAttr padAttr = op.getPadAttr();
     DenseI64ArrayAttr strideTosaAttr = op.getStrideAttr();
     DenseI64ArrayAttr dilationTosaAttr = op.getDilationAttr();
 
+    Type accETy = op.getAccType();
+    Type accTy = RankedTensorType::get(resultTy.getShape(), accETy);
+
     // Get and verify zero points.
     FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
     if (failed(maybeIZp))
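
The accumulator element type is now read from the op's acc_type attribute instead of being assumed to match the result type. A hedged sketch of the kind of TOSA op this path handles, assuming an f16 convolution that requests f32 accumulation (shapes and operands are illustrative; for float types the zero-point operands must be constant zeros):

    // Hypothetical f16 tosa.conv2d accumulating in f32, so here
    // accETy = f32 and accTy = tensor<1x30x30x16xf32>.
    %0 = tosa.conv2d %input, %weight, %bias, %input_zp, %weight_zp
           {acc_type = f32, dilation = array<i64: 1, 1>,
            pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
         : (tensor<1x32x32x8xf16>, tensor<16x3x3x8xf16>, tensor<16xf16>,
            tensor<1xf16>, tensor<1xf16>) -> tensor<1x30x30x16xf16>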
@@ -385,10 +393,10 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
     auto dilationAttr = rewriter.getI64TensorAttr(dilation);
 
     Value biasEmptyTensor = rewriter.create<tensor::EmptyOp>(
-        loc, resultTy.getShape(), resultETy, filteredDims);
+        loc, resultTy.getShape(), accETy, filteredDims);
 
     Value broadcastBias =
-        linalgBroadcastAndMaybeExtSI(rewriter, loc, bias, biasEmptyTensor);
+        linalgBroadcastAndMaybeExt(rewriter, loc, bias, biasEmptyTensor);
 
     if (hasZp) {
       auto iZp = rewriter.getI32IntegerAttr(inputZpVal);
@@ -410,10 +418,15 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
 
     Value conv = rewriter
                      .create<LinalgConvOp>(
-                         loc, resultTy, ValueRange{input, weight},
+                         loc, accTy, ValueRange{input, weight},
                          ValueRange{broadcastBias}, strideAttr, dilationAttr)
                      ->getResult(0);
 
+    // We may need to truncate back to the result type if the accumulator was
+    // wider than the result.
+    if (resultTy != accTy)
+      conv = rewriter.create<tosa::CastOp>(loc, resultTy, conv);
+
     rewriter.replaceOp(op, conv);
     return success();
   }
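
The named linalg convolution is now created with the accumulator tensor type, and a narrowing cast is emitted only when accumulator and result differ. A sketch of the lowered tail for the f16/f32 example above (op choice and shapes are illustrative):

    // The convolution accumulates in f32 ...
    %conv = linalg.conv_2d_nhwc_fhwc
        {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
        ins(%input, %weight : tensor<1x32x32x8xf16>, tensor<16x3x3x8xf16>)
        outs(%bcast : tensor<1x30x30x16xf32>) -> tensor<1x30x30x16xf32>
    // ... and tosa.cast truncates back to the declared f16 result.
    %res = tosa.cast %conv : (tensor<1x30x30x16xf32>) -> tensor<1x30x30x16xf16>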
@@ -444,6 +457,8 @@ class DepthwiseConvConverter
     auto strideTosaAttr = cast<DenseI64ArrayAttr>(op->getAttr("stride"));
     auto dilationTosaAttr = cast<DenseI64ArrayAttr>(op->getAttr("dilation"));
 
+    Type accETy = op.getAccType();
+
     if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
       return rewriter.notifyMatchFailure(
           op, "tosa.depthwise_conv ops require static shapes");
@@ -516,11 +531,11 @@ class DepthwiseConvConverter
     ShapedType linalgConvTy =
         RankedTensorType::get({resultShape[0], resultShape[1], resultShape[2],
                                weightShape[2], weightShape[3]},
-                              resultETy);
+                              accETy);
 
-    auto resultZeroAttr = rewriter.getZeroAttr(resultETy);
+    auto resultZeroAttr = rewriter.getZeroAttr(accETy);
     Value emptyTensor = rewriter.create<tensor::EmptyOp>(
-        loc, linalgConvTy.getShape(), resultETy, filteredDims);
+        loc, linalgConvTy.getShape(), accETy, filteredDims);
     Value zero = rewriter.create<arith::ConstantOp>(loc, resultZeroAttr);
     Value zeroTensor = rewriter
                            .create<linalg::FillOp>(loc, ValueRange{zero},
@@ -543,6 +558,15 @@ class DepthwiseConvConverter
                          ValueRange{zeroTensor}, strideAttr, dilationAttr)
                      .getResult(0);
 
+    // We may need to truncate back to the result type if the accumulator was
+    // wider than the result.
+    if (accETy != resultETy)
+      conv = rewriter.create<tosa::CastOp>(
+          loc,
+          RankedTensorType::get(cast<ShapedType>(conv.getType()).getShape(),
+                                resultETy),
+          conv);
+
     SmallVector<ReassociationExprs, 4> reassociationMap;
     createDepthwiseConvCollapseMap(resultRank, reassociationMap, rewriter);
     Value convReshape = rewriter.create<tensor::CollapseShapeOp>(
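
The depthwise path does the same, except the cast is applied while the result is still in its expanded 5-D (channel x multiplier) form, before the collapse back to 4-D. A sketch under the same illustrative assumptions (8 input channels, channel multiplier 2):

    %dw = linalg.depthwise_conv_2d_nhwc_hwcm
        {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
        ins(%input, %weight : tensor<1x32x32x8xf16>, tensor<3x3x8x2xf16>)
        outs(%zero_fill : tensor<1x30x30x8x2xf32>) -> tensor<1x30x30x8x2xf32>
    // Narrow to the result element type while still in the expanded shape.
    %cast = tosa.cast %dw : (tensor<1x30x30x8x2xf32>) -> tensor<1x30x30x8x2xf16>
    // Collapse channel and multiplier dims: 8 * 2 -> 16 output channels.
    %out = tensor.collapse_shape %cast [[0], [1], [2], [3, 4]]
        : tensor<1x30x30x8x2xf16> into tensor<1x30x30x16xf16>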