@@ -4169,28 +4169,35 @@ class StridedSliceBroadcast final
41694169 auto dstVecType = llvm::cast<VectorType>(op.getType ());
41704170 unsigned dstRank = dstVecType.getRank ();
41714171 unsigned rankDiff = dstRank - srcRank;
4172- // Check if the most inner dimensions of the source of the broadcast are the
4173- // same as the destination of the extract . If this is the case we can just
4174- // use a broadcast as the original dimensions are untouched .
4175- bool lowerDimMatch = true ;
4172+ // Source dimensions can be broadcasted (1 -> n with n > 1) or sliced
4173+ // (n -> m with n > m) . If they are originally both broadcasted *and*
4174+ // sliced, this can be simplified to just broadcasting .
4175+ bool needsSlice = false ;
41764176 for (unsigned i = 0 ; i < srcRank; i++) {
4177- if (srcVecType.getDimSize (i) != dstVecType.getDimSize (i + rankDiff)) {
4178- lowerDimMatch = false ;
4177+ if (srcVecType.getDimSize (i) != 1 &&
4178+ srcVecType.getDimSize (i) != dstVecType.getDimSize (i + rankDiff)) {
4179+ needsSlice = true ;
41794180 break ;
41804181 }
41814182 }
41824183 Value source = broadcast.getSource ();
4183- // If the inner dimensions don't match, it means we need to extract from the
4184- // source of the orignal broadcast and then broadcast the extracted value.
4185- // We also need to handle degenerated cases where the source is effectively
4186- // just a single scalar.
4187- bool isScalarSrc = (srcRank == 0 || srcVecType.getNumElements () == 1 );
4188- if (!lowerDimMatch && !isScalarSrc) {
4184+ if (needsSlice) {
4185+ SmallVector<int64_t > offsets =
4186+ getI64SubArray (op.getOffsets (), /* dropFront=*/ rankDiff);
4187+ SmallVector<int64_t > sizes =
4188+ getI64SubArray (op.getSizes (), /* dropFront=*/ rankDiff);
4189+ for (unsigned i = 0 ; i < srcRank; i++) {
4190+ if (srcVecType.getDimSize (i) == 1 ) {
4191+ // In case this dimension was broadcasted *and* sliced, the offset
4192+ // and size need to be updated now that there is no broadcast before
4193+ // the slice.
4194+ offsets[i] = 0 ;
4195+ sizes[i] = 1 ;
4196+ }
4197+ }
41894198 source = rewriter.create <ExtractStridedSliceOp>(
4190- op->getLoc (), source,
4191- getI64SubArray (op.getOffsets (), /* dropFront=*/ rankDiff),
4192- getI64SubArray (op.getSizes (), /* dropFront=*/ rankDiff),
4193- getI64SubArray (op.getStrides (), /* dropFront=*/ rankDiff));
4199+ op->getLoc (), source, offsets, sizes,
4200+ getI64SubArray (op.getStrides (), /* dropFront=*/ rankDiff));
41944201 }
41954202 rewriter.replaceOpWithNewOp <BroadcastOp>(op, op.getType (), source);
41964203 return success ();
0 commit comments