Skip to content

Commit 5946db5

Browse files
committed
Address comments
1 parent 0a7f4c9 commit 5946db5

File tree

1 file changed

+11
-9
lines changed

1 file changed

+11
-9
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6594,24 +6594,26 @@ bool ConstantMaskOp::isAllOnesMask() {
65946594
return true;
65956595
}
65966596

6597-
static Attribute createBoolSplat(ShapedType ty, bool x) {
6598-
return SplatElementsAttr::get(ty, BoolAttr::get(ty.getContext(), x));
6599-
}
6600-
66016597
OpFoldResult ConstantMaskOp::fold(FoldAdaptor adaptor) {
66026598
ArrayRef<int64_t> bounds = getMaskDimSizes();
66036599
ArrayRef<int64_t> vectorSizes = getVectorType().getShape();
6600+
6601+
auto createBoolSplat = [&](bool x) {
6602+
return SplatElementsAttr::get(getVectorType(),
6603+
BoolAttr::get(getContext(), x));
6604+
};
6605+
66046606
// Check the corner case of 0-D vectors first.
6605-
if (vectorSizes.size() == 0) {
6607+
if (vectorSizes.empty()) {
66066608
assert(bounds.size() == 1 && "invalid sizes for zero rank mask");
6607-
return createBoolSplat(getVectorType(), bounds[0] == 1);
6609+
return createBoolSplat(bounds[0] == 1);
66086610
}
66096611
// Fold vector.constant_mask to splat if possible.
66106612
if (bounds == vectorSizes)
6611-
return createBoolSplat(getVectorType(), true);
6613+
return createBoolSplat(true);
66126614
if (llvm::all_of(bounds, [](int64_t x) { return x == 0; }))
6613-
return createBoolSplat(getVectorType(), false);
6614-
return {};
6615+
return createBoolSplat(false);
6616+
return OpFoldResult();
66156617
}
66166618

66176619
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)