Skip to content

Commit 8ae8f7b

Browse files
authored
merge main into amd-staging (llvm#1756)
2 parents 5c41d86 + a2a0e99 commit 8ae8f7b

File tree

24 files changed

+331
-174
lines changed

24 files changed

+331
-174
lines changed

clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
5757
public:
5858
CIRBaseBuilderTy(mlir::MLIRContext &mlirContext)
5959
: mlir::OpBuilder(&mlirContext) {}
60+
CIRBaseBuilderTy(mlir::OpBuilder &builder) : mlir::OpBuilder(builder) {}
6061

6162
mlir::Value getConstAPInt(mlir::Location loc, mlir::Type typ,
6263
const llvm::APInt &val) {
@@ -98,13 +99,13 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
9899
if (auto recordTy = mlir::dyn_cast<cir::RecordType>(ty))
99100
return getZeroAttr(recordTy);
100101
if (mlir::isa<cir::BoolType>(ty)) {
101-
return getCIRBoolAttr(false);
102+
return getFalseAttr();
102103
}
103104
llvm_unreachable("Zero initializer for given type is NYI");
104105
}
105106

106107
cir::ConstantOp getBool(bool state, mlir::Location loc) {
107-
return create<cir::ConstantOp>(loc, getBoolTy(), getCIRBoolAttr(state));
108+
return create<cir::ConstantOp>(loc, getCIRBoolAttr(state));
108109
}
109110
cir::ConstantOp getFalse(mlir::Location loc) { return getBool(false, loc); }
110111
cir::ConstantOp getTrue(mlir::Location loc) { return getBool(true, loc); }
@@ -120,9 +121,12 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
120121
}
121122

122123
cir::BoolAttr getCIRBoolAttr(bool state) {
123-
return cir::BoolAttr::get(getContext(), getBoolTy(), state);
124+
return cir::BoolAttr::get(getContext(), state);
124125
}
125126

127+
cir::BoolAttr getTrueAttr() { return getCIRBoolAttr(true); }
128+
cir::BoolAttr getFalseAttr() { return getCIRBoolAttr(false); }
129+
126130
mlir::Value createNot(mlir::Value value) {
127131
return create<cir::UnaryOp>(value.getLoc(), value.getType(),
128132
cir::UnaryOpKind::Not, value);

clang/include/clang/CIR/Dialect/IR/CIRAttrs.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,12 @@ def CIR_BoolAttr : CIR_Attr<"Bool", "bool", [TypedAttrInterface]> {
4949
"", "cir::BoolType">:$type,
5050
"bool":$value);
5151

52+
let builders = [
53+
AttrBuilder<(ins "bool":$value), [{
54+
return $_get($_ctxt, cir::BoolType::get($_ctxt), value);
55+
}]>,
56+
];
57+
5258
let assemblyFormat = [{
5359
`<` $value `>`
5460
}];

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,12 @@ def ConstantOp : CIR_Op<"const",
294294
// The constant operation returns a single value of CIR_AnyType.
295295
let results = (outs CIR_AnyType:$res);
296296

297+
let builders = [
298+
OpBuilder<(ins "cir::BoolAttr":$value), [{
299+
build($_builder, $_state, value.getType(), value);
300+
}]>
301+
];
302+
297303
let assemblyFormat = "attr-dict $value";
298304

299305
let hasVerifier = 1;
@@ -844,7 +850,7 @@ def UnaryOp : CIR_Op<"unary", [Pure, SameOperandsAndResultType]> {
844850
let assemblyFormat = [{
845851
`(` $kind `,` $input `)`
846852
(`nsw` $no_signed_wrap^)?
847-
`:` type($input) `,` type($result) attr-dict
853+
`:` type($input) `,` type($result) attr-dict
848854
}];
849855

850856
let hasVerifier = 1;

clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -152,10 +152,7 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
152152
}
153153

154154
mlir::Value VisitCXXBoolLiteralExpr(const CXXBoolLiteralExpr *e) {
155-
mlir::Type type = cgf.convertType(e->getType());
156-
return builder.create<cir::ConstantOp>(
157-
cgf.getLoc(e->getExprLoc()), type,
158-
builder.getCIRBoolAttr(e->getValue()));
155+
return builder.getBool(e->getValue(), cgf.getLoc(e->getExprLoc()));
159156
}
160157

161158
mlir::Value VisitCastExpr(CastExpr *e);
@@ -215,9 +212,7 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
215212

216213
if (llvm::isa<MemberPointerType>(srcType)) {
217214
cgf.getCIRGenModule().errorNYI(loc, "member pointer to bool conversion");
218-
mlir::Type boolType = builder.getBoolTy();
219-
return builder.create<cir::ConstantOp>(loc, boolType,
220-
builder.getCIRBoolAttr(false));
215+
return builder.getFalse(loc);
221216
}
222217

223218
if (srcType->isIntegerType())
@@ -354,9 +349,7 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
354349
// An interesting aspect of this is that increment is always true.
355350
// Decrement does not have this property.
356351
if (isInc && type->isBooleanType()) {
357-
value = builder.create<cir::ConstantOp>(cgf.getLoc(e->getExprLoc()),
358-
cgf.convertType(type),
359-
builder.getCIRBoolAttr(true));
352+
value = builder.getTrue(cgf.getLoc(e->getExprLoc()));
360353
} else if (type->isIntegerType()) {
361354
QualType promotedType;
362355
bool canPerformLossyDemotionCheck = false;

clang/lib/CIR/CodeGen/CIRGenStmt.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -456,9 +456,7 @@ mlir::LogicalResult CIRGenFunction::emitForStmt(const ForStmt &s) {
456456
// scalar type.
457457
condVal = evaluateExprAsBool(s.getCond());
458458
} else {
459-
cir::BoolType boolTy = cir::BoolType::get(b.getContext());
460-
condVal = b.create<cir::ConstantOp>(
461-
loc, boolTy, cir::BoolAttr::get(b.getContext(), boolTy, true));
459+
condVal = b.create<cir::ConstantOp>(loc, builder.getTrueAttr());
462460
}
463461
builder.createCondition(condVal);
464462
},

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -692,9 +692,7 @@ mlir::LogicalResult CIRToLLVMConstantOpLowering::matchAndRewrite(
692692
// during a pass as long as they don't live past the end of the pass.
693693
attr = op.getValue();
694694
} else if (mlir::isa<cir::BoolType>(op.getType())) {
695-
int value = (op.getValue() ==
696-
cir::BoolAttr::get(getContext(),
697-
cir::BoolType::get(getContext()), true));
695+
int value = mlir::cast<cir::BoolAttr>(op.getValue()).getValue();
698696
attr = rewriter.getIntegerAttr(typeConverter->convertType(op.getType()),
699697
value);
700698
} else if (mlir::isa<cir::IntType>(op.getType())) {

llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2114,13 +2114,11 @@ void AMDGPUCodeGenPassBuilder::addCodeGenPrepare(AddIRPass &addPass) const {
21142114

21152115
void AMDGPUCodeGenPassBuilder::addPreISel(AddIRPass &addPass) const {
21162116

2117-
if (TM.getOptLevel() > CodeGenOptLevel::None)
2117+
if (TM.getOptLevel() > CodeGenOptLevel::None) {
21182118
addPass(FlattenCFGPass());
2119-
2120-
if (TM.getOptLevel() > CodeGenOptLevel::None)
21212119
addPass(SinkingPass());
2122-
2123-
addPass(AMDGPULateCodeGenPreparePass(TM));
2120+
addPass(AMDGPULateCodeGenPreparePass(TM));
2121+
}
21242122

21252123
// Merge divergent exit nodes. StructurizeCFG won't recognize the multi-exit
21262124
// regions formed by them.

llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1876,6 +1876,51 @@ static bool isConstantOrUndefBUILD_VECTOR(const BuildVectorSDNode *Op) {
18761876
return false;
18771877
}
18781878

1879+
// Lower BUILD_VECTOR as broadcast load (if possible).
1880+
// For example:
1881+
// %a = load i8, ptr %ptr
1882+
// %b = build_vector %a, %a, %a, %a
1883+
// is lowered to :
1884+
// (VLDREPL_B $a0, 0)
1885+
static SDValue lowerBUILD_VECTORAsBroadCastLoad(BuildVectorSDNode *BVOp,
1886+
const SDLoc &DL,
1887+
SelectionDAG &DAG) {
1888+
MVT VT = BVOp->getSimpleValueType(0);
1889+
int NumOps = BVOp->getNumOperands();
1890+
1891+
assert((VT.is128BitVector() || VT.is256BitVector()) &&
1892+
"Unsupported vector type for broadcast.");
1893+
1894+
SDValue IdentitySrc;
1895+
bool IsIdeneity = true;
1896+
1897+
for (int i = 0; i != NumOps; i++) {
1898+
SDValue Op = BVOp->getOperand(i);
1899+
if (Op.getOpcode() != ISD::LOAD || (IdentitySrc && Op != IdentitySrc)) {
1900+
IsIdeneity = false;
1901+
break;
1902+
}
1903+
IdentitySrc = BVOp->getOperand(0);
1904+
}
1905+
1906+
// make sure that this load is valid and only has one user.
1907+
if (!IdentitySrc || !BVOp->isOnlyUserOf(IdentitySrc.getNode()))
1908+
return SDValue();
1909+
1910+
if (IsIdeneity) {
1911+
auto *LN = cast<LoadSDNode>(IdentitySrc);
1912+
SDVTList Tys =
1913+
LN->isIndexed()
1914+
? DAG.getVTList(VT, LN->getBasePtr().getValueType(), MVT::Other)
1915+
: DAG.getVTList(VT, MVT::Other);
1916+
SDValue Ops[] = {LN->getChain(), LN->getBasePtr(), LN->getOffset()};
1917+
SDValue BCast = DAG.getNode(LoongArchISD::VLDREPL, DL, Tys, Ops);
1918+
DAG.ReplaceAllUsesOfValueWith(SDValue(LN, 1), BCast.getValue(1));
1919+
return BCast;
1920+
}
1921+
return SDValue();
1922+
}
1923+
18791924
SDValue LoongArchTargetLowering::lowerBUILD_VECTOR(SDValue Op,
18801925
SelectionDAG &DAG) const {
18811926
BuildVectorSDNode *Node = cast<BuildVectorSDNode>(Op);
@@ -1891,6 +1936,9 @@ SDValue LoongArchTargetLowering::lowerBUILD_VECTOR(SDValue Op,
18911936
(!Subtarget.hasExtLASX() || !Is256Vec))
18921937
return SDValue();
18931938

1939+
if (SDValue Result = lowerBUILD_VECTORAsBroadCastLoad(Node, DL, DAG))
1940+
return Result;
1941+
18941942
if (Node->isConstantSplat(SplatValue, SplatUndef, SplatBitSize, HasAnyUndefs,
18951943
/*MinSplatBits=*/8) &&
18961944
SplatBitSize <= 64) {
@@ -5326,6 +5374,7 @@ const char *LoongArchTargetLowering::getTargetNodeName(unsigned Opcode) const {
53265374
NODE_NAME_CASE(VSRLI)
53275375
NODE_NAME_CASE(VBSLL)
53285376
NODE_NAME_CASE(VBSRL)
5377+
NODE_NAME_CASE(VLDREPL)
53295378
}
53305379
#undef NODE_NAME_CASE
53315380
return nullptr;

llvm/lib/Target/LoongArch/LoongArchISelLowering.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,10 @@ enum NodeType : unsigned {
155155

156156
// Vector byte logicial left / right shift
157157
VBSLL,
158-
VBSRL
158+
VBSRL,
159+
160+
// Scalar load broadcast to vector
161+
VLDREPL
159162

160163
// Intrinsic operations end =============================================
161164
};

llvm/lib/Target/LoongArch/LoongArchInstrInfo.td

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,8 @@ def simm8_lsl # I : Operand<GRLenVT> {
307307
}
308308
}
309309

310-
def simm9_lsl3 : Operand<GRLenVT> {
310+
def simm9_lsl3 : Operand<GRLenVT>,
311+
ImmLeaf<GRLenVT, [{return isShiftedInt<9,3>(Imm);}]> {
311312
let ParserMatchClass = SImmAsmOperand<9, "lsl3">;
312313
let EncoderMethod = "getImmOpValueAsr<3>";
313314
let DecoderMethod = "decodeSImmOperand<9, 3>";
@@ -317,13 +318,15 @@ def simm10 : Operand<GRLenVT> {
317318
let ParserMatchClass = SImmAsmOperand<10>;
318319
}
319320

320-
def simm10_lsl2 : Operand<GRLenVT> {
321+
def simm10_lsl2 : Operand<GRLenVT>,
322+
ImmLeaf<GRLenVT, [{return isShiftedInt<10,2>(Imm);}]> {
321323
let ParserMatchClass = SImmAsmOperand<10, "lsl2">;
322324
let EncoderMethod = "getImmOpValueAsr<2>";
323325
let DecoderMethod = "decodeSImmOperand<10, 2>";
324326
}
325327

326-
def simm11_lsl1 : Operand<GRLenVT> {
328+
def simm11_lsl1 : Operand<GRLenVT>,
329+
ImmLeaf<GRLenVT, [{return isShiftedInt<11,1>(Imm);}]> {
327330
let ParserMatchClass = SImmAsmOperand<11, "lsl1">;
328331
let EncoderMethod = "getImmOpValueAsr<1>";
329332
let DecoderMethod = "decodeSImmOperand<11, 1>";

0 commit comments

Comments
 (0)