Skip to content

Commit 7de533b

Browse files
committed
Resolve PR reviews
1 parent 4ceae74 commit 7de533b

File tree

3 files changed

+38
-47
lines changed

3 files changed

+38
-47
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,32 @@ def CIR_ConstantOp : CIR_Op<"const", [
425425
return boolAttr.getValue();
426426
llvm_unreachable("Expected a BoolAttr in ConstantOp");
427427
}
428+
static bool isAllOnesValue(mlir::Value value) {
429+
auto constOp = mlir::dyn_cast_or_null<cir::ConstantOp>(value.getDefiningOp());
430+
if (!constOp)
431+
return false;
432+
433+
// Check for -1 integers
434+
if (auto intAttr = constOp.getValueAttr<cir::IntAttr>())
435+
return intAttr.getValue().isAllOnes();
436+
437+
// Check for FP which are bitcasted from -1 integers
438+
if (auto fpAttr = constOp.getValueAttr<cir::FPAttr>())
439+
return fpAttr.getValue().bitcastToAPInt().isAllOnes();
440+
441+
442+
// Check for constant vectors with splat values
443+
if (cir::VectorType v = mlir::dyn_cast<cir::VectorType>(constOp.getType()))
444+
if (auto vecAttr = constOp.getValueAttr<mlir::DenseElementsAttr>())
445+
if (vecAttr.isSplat()) {
446+
auto splatAttr = vecAttr.getSplatValue<mlir::Attribute>();
447+
if (auto splatInt = mlir::dyn_cast<cir::IntAttr>(splatAttr)) {
448+
return splatInt.getValue().isAllOnes();
449+
}
450+
}
451+
452+
return false;
453+
}
428454
}];
429455

430456
let hasFolder = 1;
@@ -1885,8 +1911,12 @@ def CIR_SelectOp : CIR_Op<"select", [
18851911
```
18861912
}];
18871913

1888-
let arguments = (ins CIR_ScalarOrVectorOf<CIR_BoolType>:$condition, CIR_AnyType:$true_value,
1889-
CIR_AnyType:$false_value);
1914+
let arguments = (ins
1915+
CIR_ScalarOrVectorOf<CIR_BoolType>:$condition,
1916+
CIR_AnyType:$true_value,
1917+
CIR_AnyType:$false_value
1918+
);
1919+
18901920
let results = (outs CIR_AnyType:$result);
18911921

18921922
let assemblyFormat = [{

clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -266,15 +266,8 @@ class CIR_VectorTypeOf<list<Type> types, string summary = "">
266266
"vector of " # CIR_TypeSummaries<types>.value,
267267
summary)>;
268268

269-
class CIR_VectorOf<Type T> : CIR_ConfinedType<
270-
CIR_AnyVectorType,
271-
[CIR_ElementTypePred<T.predicate>],
272-
"CIR vector of " # T.summary>;
273-
274269
// Type constraint accepting a either a type T or a vector of type T
275-
// Mimicking LLVMIR's LLVM_ScalarOrVectorOf
276-
class CIR_ScalarOrVectorOf<Type T> :
277-
AnyTypeOf<[T, CIR_VectorOf<T>]>;
270+
class CIR_ScalarOrVectorOf<Type T> : AnyTypeOf<[T, CIR_VectorTypeOf<[T]>]>;
278271

279272
// Vector of integral type
280273
def IntegerVector : Type<

clang/lib/CIR/CodeGen/CIRGenBuiltinX86.cpp

Lines changed: 5 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,7 @@ static mlir::Value getBoolMaskVecValue(CIRGenBuilderTy &builder,
163163

164164
if (numElems < 8) {
165165
SmallVector<mlir::Attribute, 4> indices;
166+
indices.reserve(numElems);
166167
mlir::Type i32Ty = builder.getSInt32Ty();
167168
for (auto i : llvm::seq<unsigned>(0, numElems))
168169
indices.push_back(cir::IntAttr::get(i32Ty, i));
@@ -172,43 +173,11 @@ static mlir::Value getBoolMaskVecValue(CIRGenBuilderTy &builder,
172173
return maskVec;
173174
}
174175

175-
// Helper function mirroring OG's bool Constant::isAllOnesValue()
176-
static bool isAllOnesValue(mlir::Value value) {
177-
auto constOp = mlir::dyn_cast_or_null<cir::ConstantOp>(value.getDefiningOp());
178-
if (!constOp)
179-
return false;
180-
181-
// Check for -1 integers
182-
if (auto intAttr = constOp.getValueAttr<cir::IntAttr>()) {
183-
return intAttr.getValue().isAllOnes();
184-
}
185-
186-
// Check for FP which are bitcasted from -1 integers
187-
if (auto fpAttr = constOp.getValueAttr<cir::FPAttr>()) {
188-
return fpAttr.getValue().bitcastToAPInt().isAllOnes();
189-
}
190-
191-
// Check for constant vectors with splat values
192-
if (cir::VectorType v = dyn_cast<cir::VectorType>(constOp.getType())) {
193-
if (auto vecAttr = constOp.getValueAttr<mlir::DenseElementsAttr>()) {
194-
if (vecAttr.isSplat()) {
195-
auto splatAttr = vecAttr.getSplatValue<mlir::Attribute>();
196-
if (auto splatInt = mlir::dyn_cast<cir::IntAttr>(splatAttr)) {
197-
return splatInt.getValue().isAllOnes();
198-
}
199-
}
200-
}
201-
}
202-
203-
return false;
204-
}
205-
206176
static mlir::Value emitX86Select(CIRGenBuilderTy &builder, mlir::Location loc,
207177
mlir::Value mask, mlir::Value op0,
208178
mlir::Value op1) {
209-
210179
// If the mask is all ones just return first argument.
211-
if (isAllOnesValue(mask))
180+
if (cir::ConstantOp::isAllOnesValue(mask))
212181
return op0;
213182

214183
mask = getBoolMaskVecValue(builder, loc, mask,
@@ -958,22 +927,21 @@ mlir::Value CIRGenFunction::emitX86BuiltinExpr(unsigned builtinID,
958927
unsigned numElts = dstTy.getSize();
959928
unsigned srcNumElts = cast<cir::VectorType>(ops[0].getType()).getSize();
960929
unsigned subVectors = srcNumElts / numElts;
930+
assert(llvm::isPowerOf2_32(subVectors) && "Expected power of 2 subvectors");
961931
unsigned index =
962932
ops[1].getDefiningOp<cir::ConstantOp>().getIntValue().getZExtValue();
963933

964934
index &= subVectors - 1; // Remove any extra bits.
965935
index *= numElts;
966936

967937
int64_t indices[16];
968-
for (unsigned i = 0; i != numElts; ++i)
969-
indices[i] = i + index;
938+
std::iota(indices, indices + numElts, index);
970939

971940
mlir::Value zero = builder.getNullValue(ops[0].getType(), loc);
972941
mlir::Value res =
973942
builder.createVecShuffle(loc, ops[0], zero, ArrayRef(indices, numElts));
974-
if (ops.size() == 4) {
943+
if (ops.size() == 4)
975944
res = emitX86Select(builder, loc, ops[3], res, ops[2]);
976-
}
977945

978946
return res;
979947
}

0 commit comments

Comments
 (0)