@@ -649,24 +649,29 @@ OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
649649 return getShapes ().front ();
650650 }
651651
652- // TODO: Support folding with more than 2 input shapes
653- if (getShapes ().size () > 2 )
652+ if (!adaptor.getShapes ().front ())
654653 return nullptr ;
655654
656- if (!adaptor.getShapes ()[0 ] || !adaptor.getShapes ()[1 ])
657- return nullptr ;
658- auto lhsShape = llvm::to_vector<6 >(
659- llvm::cast<DenseIntElementsAttr>(adaptor.getShapes ()[0 ])
660- .getValues <int64_t >());
661- auto rhsShape = llvm::to_vector<6 >(
662- llvm::cast<DenseIntElementsAttr>(adaptor.getShapes ()[1 ])
655+ SmallVector<int64_t , 6 > resultShape (
656+ llvm::cast<DenseIntElementsAttr>(adaptor.getShapes ().front ())
663657 .getValues <int64_t >());
664- SmallVector<int64_t , 6 > resultShape;
665658
666- // If the shapes are not compatible, we can't fold it.
667- // TODO: Fold to an "error".
668- if (!OpTrait::util::getBroadcastedShape (lhsShape, rhsShape, resultShape))
669- return nullptr ;
659+ for (auto next : adaptor.getShapes ().drop_front ()) {
660+ if (!next)
661+ return nullptr ;
662+ auto nextShape = llvm::to_vector<6 >(
663+ llvm::cast<DenseIntElementsAttr>(next).getValues <int64_t >());
664+
665+ SmallVector<int64_t , 6 > tmpShape;
666+ // If the shapes are not compatible, we can't fold it.
667+ // TODO: Fold to an "error".
668+ if (!OpTrait::util::getBroadcastedShape (resultShape, nextShape, tmpShape))
669+ return nullptr ;
670+
671+ resultShape.clear ();
672+ std::copy (tmpShape.begin (), tmpShape.end (),
673+ std::back_inserter (resultShape));
674+ }
670675
671676 Builder builder (getContext ());
672677 return builder.getIndexTensorAttr (resultShape);
0 commit comments