@@ -6594,24 +6594,26 @@ bool ConstantMaskOp::isAllOnesMask() {
65946594 return true ;
65956595}
65966596
6597- static Attribute createBoolSplat (ShapedType ty, bool x) {
6598- return SplatElementsAttr::get (ty, BoolAttr::get (ty.getContext (), x));
6599- }
6600-
66016597OpFoldResult ConstantMaskOp::fold (FoldAdaptor adaptor) {
66026598 ArrayRef<int64_t > bounds = getMaskDimSizes ();
66036599 ArrayRef<int64_t > vectorSizes = getVectorType ().getShape ();
6600+
6601+ auto createBoolSplat = [&](bool x) {
6602+ return SplatElementsAttr::get (getVectorType (),
6603+ BoolAttr::get (getContext (), x));
6604+ };
6605+
66046606 // Check the corner case of 0-D vectors first.
6605- if (vectorSizes.size () == 0 ) {
6607+ if (vectorSizes.empty () ) {
66066608 assert (bounds.size () == 1 && " invalid sizes for zero rank mask" );
6607- return createBoolSplat (getVectorType (), bounds[0 ] == 1 );
6609+ return createBoolSplat (bounds[0 ] == 1 );
66086610 }
66096611 // Fold vector.constant_mask to splat if possible.
66106612 if (bounds == vectorSizes)
6611- return createBoolSplat (getVectorType (), true );
6613+ return createBoolSplat (true );
66126614 if (llvm::all_of (bounds, [](int64_t x) { return x == 0 ; }))
6613- return createBoolSplat (getVectorType (), false );
6614- return {} ;
6615+ return createBoolSplat (false );
6616+ return OpFoldResult () ;
66156617}
66166618
66176619// ===----------------------------------------------------------------------===//
0 commit comments