@@ -88,15 +88,14 @@ linalgIntBroadcastExtSIAdd(PatternRewriter &rewriter, Location loc, Value bias,
8888 .getResult (0 );
8989}
9090
91- // Broadcast the source value to all the outer dimensions of the result value.
92- // If required, the element type is expanded using an arith.extsi operation.
93- static mlir::Value linalgBroadcastAndMaybeExtSI (PatternRewriter &rewriter,
94- Location loc, Value source,
95- Value result) {
91+ // Construct the affine map that a linalg generic would use to broadcast the
92+ // source tensor into the shape of the result tensor.
93+ static AffineMap getBroadcastingMap (PatternRewriter &rewriter, Value source,
94+ Value result) {
9695 ShapedType resultTy = cast<ShapedType>(result.getType ());
9796 ShapedType sourceTy = cast<ShapedType>(source.getType ());
98- int64_t resultRank = resultTy.getRank ();
99- int64_t sourceRank = sourceTy.getRank ();
97+ const int64_t resultRank = resultTy.getRank ();
98+ const int64_t sourceRank = sourceTy.getRank ();
10099
101100 // The source tensor is broadcast to all the outer dimensions of the
102101 // result tensor.
@@ -115,14 +114,21 @@ static mlir::Value linalgBroadcastAndMaybeExtSI(PatternRewriter &rewriter,
115114 }
116115 }
117116
118- // Creating maps for the input and output of the broacast-like generic op.
119- SmallVector<AffineMap, 2 > indexingMaps = {
120- // Broadcast the last dimension of the bias to all output dimensions.
121- AffineMap::get (/* dimCount=*/ resultRank,
122- /* symbolCount=*/ 0 , sourceDims, rewriter.getContext ()),
117+ return AffineMap::get (/* dimCount=*/ resultRank,
118+ /* symbolCount=*/ 0 , sourceDims, rewriter.getContext ());
119+ }
123120
124- // Output indexing map.
125- rewriter.getMultiDimIdentityMap (resultRank)};
121+ // Broadcast the source value to all the outer dimensions of the result value.
122+ // If required, the element type is expanded using an arith.extsi operation.
123+ static mlir::Value linalgBroadcastAndMaybeExtSI (PatternRewriter &rewriter,
124+ Location loc, Value source,
125+ Value result) {
126+ ShapedType resultTy = cast<ShapedType>(result.getType ());
127+ const int64_t resultRank = resultTy.getRank ();
128+ // Creating maps for the input and output of the broacast-like generic op.
129+ SmallVector<AffineMap, 2 > indexingMaps;
130+ indexingMaps.push_back (getBroadcastingMap (rewriter, source, result));
131+ indexingMaps.push_back (rewriter.getMultiDimIdentityMap (resultRank));
126132
127133 // Build the broadcast-like operation as a linalg.generic.
128134 return rewriter
@@ -488,14 +494,6 @@ class DepthwiseConvConverter
488494 weightShape[2 ], weightShape[3 ]},
489495 resultETy);
490496
491- // Broadcast the initial value to the output tensor before convolving.
492- SmallVector<AffineMap, 4 > indexingMaps;
493- indexingMaps.push_back (AffineMap::get (
494- /* dimCount=*/ resultRank, /* symbolCount=*/ 0 ,
495- {rewriter.getAffineDimExpr (3 )}, rewriter.getContext ()));
496- indexingMaps.push_back (rewriter.getMultiDimIdentityMap (resultRank));
497- indexingMaps.push_back (rewriter.getMultiDimIdentityMap (resultRank));
498-
499497 auto resultZeroAttr = rewriter.getZeroAttr (resultETy);
500498 Value emptyTensor = rewriter.create <tensor::EmptyOp>(
501499 loc, linalgConvTy.getShape (), resultETy, filteredDims);
@@ -507,6 +505,13 @@ class DepthwiseConvConverter
507505
508506 Value biasEmptyTensor = rewriter.create <tensor::EmptyOp>(
509507 loc, resultTy.getShape (), resultETy, filteredDims);
508+
509+ // Broadcast the initial value to the output tensor before convolving.
510+ SmallVector<AffineMap, 4 > indexingMaps;
511+ indexingMaps.push_back (getBroadcastingMap (rewriter, bias, biasEmptyTensor));
512+ indexingMaps.push_back (rewriter.getMultiDimIdentityMap (resultRank));
513+ indexingMaps.push_back (rewriter.getMultiDimIdentityMap (resultRank));
514+
510515 if (!isQuantized) {
511516 Value conv = rewriter
512517 .create <linalg::DepthwiseConv2DNhwcHwcmOp>(
0 commit comments