@@ -4237,28 +4237,35 @@ class StridedSliceBroadcast final
42374237 auto dstVecType = llvm::cast<VectorType>(op.getType ());
42384238 unsigned dstRank = dstVecType.getRank ();
42394239 unsigned rankDiff = dstRank - srcRank;
4240- // Check if the most inner dimensions of the source of the broadcast are the
4241- // same as the destination of the extract . If this is the case we can just
4242- // use a broadcast as the original dimensions are untouched .
4243- bool lowerDimMatch = true ;
4240+ // Source dimensions can be broadcasted (1 -> n with n > 1) or sliced
4241+ // (n -> m with n > m) . If they are originally both broadcasted *and*
4242+ // sliced, this can be simplified to just broadcasting .
4243+ bool needsSlice = false ;
42444244 for (unsigned i = 0 ; i < srcRank; i++) {
4245- if (srcVecType.getDimSize (i) != dstVecType.getDimSize (i + rankDiff)) {
4246- lowerDimMatch = false ;
4245+ if (srcVecType.getDimSize (i) != 1 &&
4246+ srcVecType.getDimSize (i) != dstVecType.getDimSize (i + rankDiff)) {
4247+ needsSlice = true ;
42474248 break ;
42484249 }
42494250 }
42504251 Value source = broadcast.getSource ();
4251- // If the inner dimensions don't match, it means we need to extract from the
4252- // source of the orignal broadcast and then broadcast the extracted value.
4253- // We also need to handle degenerated cases where the source is effectively
4254- // just a single scalar.
4255- bool isScalarSrc = (srcRank == 0 || srcVecType.getNumElements () == 1 );
4256- if (!lowerDimMatch && !isScalarSrc) {
4252+ if (needsSlice) {
4253+ SmallVector<int64_t > offsets =
4254+ getI64SubArray (op.getOffsets (), /* dropFront=*/ rankDiff);
4255+ SmallVector<int64_t > sizes =
4256+ getI64SubArray (op.getSizes (), /* dropFront=*/ rankDiff);
4257+ for (unsigned i = 0 ; i < srcRank; i++) {
4258+ if (srcVecType.getDimSize (i) == 1 ) {
4259+ // In case this dimension was broadcasted *and* sliced, the offset
4260+ // and size need to be updated now that there is no broadcast before
4261+ // the slice.
4262+ offsets[i] = 0 ;
4263+ sizes[i] = 1 ;
4264+ }
4265+ }
42574266 source = rewriter.create <ExtractStridedSliceOp>(
4258- op->getLoc (), source,
4259- getI64SubArray (op.getOffsets (), /* dropFront=*/ rankDiff),
4260- getI64SubArray (op.getSizes (), /* dropFront=*/ rankDiff),
4261- getI64SubArray (op.getStrides (), /* dropFront=*/ rankDiff));
4267+ op->getLoc (), source, offsets, sizes,
4268+ getI64SubArray (op.getStrides (), /* dropFront=*/ rankDiff));
42624269 }
42634270 rewriter.replaceOpWithNewOp <BroadcastOp>(op, op.getType (), source);
42644271 return success ();
0 commit comments