@@ -649,24 +649,34 @@ OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
649649 return getShapes ().front ();
650650 }
651651
652- // TODO: Support folding with more than 2 input shapes
653- if (getShapes ().size () > 2 )
652+ if (!adaptor.getShapes ().front ())
654653 return nullptr ;
655654
656- if (!adaptor.getShapes ()[0 ] || !adaptor.getShapes ()[1 ])
657- return nullptr ;
658- auto lhsShape = llvm::to_vector<6 >(
659- llvm::cast<DenseIntElementsAttr>(adaptor.getShapes ()[0 ])
660- .getValues <int64_t >());
661- auto rhsShape = llvm::to_vector<6 >(
662- llvm::cast<DenseIntElementsAttr>(adaptor.getShapes ()[1 ])
655+ auto firstShape = llvm::to_vector<6 >(
656+ llvm::cast<DenseIntElementsAttr>(adaptor.getShapes ().front ())
663657 .getValues <int64_t >());
658+
664659 SmallVector<int64_t , 6 > resultShape;
660+ resultShape.clear ();
661+ std::copy (firstShape.begin (), firstShape.end (),
662+ std::back_inserter (resultShape));
665663
666- // If the shapes are not compatible, we can't fold it.
667- // TODO: Fold to an "error".
668- if (!OpTrait::util::getBroadcastedShape (lhsShape, rhsShape, resultShape))
669- return nullptr ;
664+ for (auto next : adaptor.getShapes ().drop_front ()) {
665+ if (!next)
666+ return nullptr ;
667+ auto nextShape = llvm::to_vector<6 >(
668+ llvm::cast<DenseIntElementsAttr>(next).getValues <int64_t >());
669+
670+ SmallVector<int64_t , 6 > tmpShape;
671+ // If the shapes are not compatible, we can't fold it.
672+ // TODO: Fold to an "error".
673+ if (!OpTrait::util::getBroadcastedShape (resultShape, nextShape, tmpShape))
674+ return nullptr ;
675+
676+ resultShape.clear ();
677+ std::copy (tmpShape.begin (), tmpShape.end (),
678+ std::back_inserter (resultShape));
679+ }
670680
671681 Builder builder (getContext ());
672682 return builder.getIndexTensorAttr (resultShape);
0 commit comments