Skip to content

Commit 32285f4

Browse files
[CIR] Upstream three way compare op
1 parent e6f60a6 commit 32285f4

File tree

10 files changed

+759
-2
lines changed

10 files changed

+759
-2
lines changed

clang/include/clang/CIR/Dialect/IR/CIRAttrs.td

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -447,6 +447,68 @@ def CIR_ConstPtrAttr : CIR_Attr<"ConstPtr", "ptr", [TypedAttrInterface]> {
447447
}];
448448
}
449449

450+
//===----------------------------------------------------------------------===//
451+
// CmpThreeWayInfoAttr
452+
//===----------------------------------------------------------------------===//
453+
454+
def CIR_CmpOrdering : CIR_I32EnumAttr<
455+
"CmpOrdering", "three-way comparison ordering kind", [
456+
I32EnumAttrCase<"Strong", 0, "strong">,
457+
I32EnumAttrCase<"Partial", 1, "partial">
458+
]> {
459+
let genSpecializedAttr = 0;
460+
}
461+
462+
def CIR_CmpThreeWayInfoAttr : CIR_Attr<"CmpThreeWayInfo", "cmp3way_info"> {
463+
let summary = "Holds information about a three-way comparison operation";
464+
let description = [{
465+
The `#cmp3way_info` attribute contains information about a three-way
466+
comparison operation `cir.cmp3way`.
467+
468+
The `ordering` parameter gives the ordering kind of the three-way comparison
469+
operation. It may be either strong ordering or partial ordering.
470+
471+
Given the two input operands of the three-way comparison operation `lhs` and
472+
`rhs`, the `lt`, `eq`, `gt`, and `unordered` parameters gives the result
473+
value that should be produced by the three-way comparison operation when the
474+
ordering between `lhs` and `rhs` is `lhs < rhs`, `lhs == rhs`, `lhs > rhs`,
475+
or neither, respectively.
476+
}];
477+
478+
let parameters = (ins
479+
EnumParameter<CIR_CmpOrdering>:$ordering,
480+
"int64_t":$lt, "int64_t":$eq, "int64_t":$gt,
481+
OptionalParameter<"std::optional<int64_t>">:$unordered
482+
);
483+
484+
let builders = [
485+
AttrBuilder<(ins "int64_t":$lt, "int64_t":$eq, "int64_t":$gt), [{
486+
return $_get($_ctxt, CmpOrdering::Strong, lt, eq, gt, std::nullopt);
487+
}]>,
488+
AttrBuilder<(ins "int64_t":$lt, "int64_t":$eq, "int64_t":$gt,
489+
"int64_t":$unordered), [{
490+
return $_get($_ctxt, CmpOrdering::Partial, lt, eq, gt, unordered);
491+
}]>,
492+
];
493+
494+
let extraClassDeclaration = [{
495+
/// Get attribute alias name for this attribute.
496+
std::string getAlias() const;
497+
}];
498+
499+
let assemblyFormat = [{
500+
`<`
501+
$ordering `,`
502+
`lt` `=` $lt `,`
503+
`eq` `=` $eq `,`
504+
`gt` `=` $gt
505+
(`,` `unordered` `=` $unordered^)?
506+
`>`
507+
}];
508+
509+
let genVerifyDecl = 1;
510+
}
511+
450512
//===----------------------------------------------------------------------===//
451513
// GlobalViewAttr
452514
//===----------------------------------------------------------------------===//

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -958,6 +958,71 @@ def CIR_ScopeOp : CIR_Op<"scope", [
958958
let hasLLVMLowering = false;
959959
}
960960

961+
//===----------------------------------------------------------------------===//
962+
// CmpThreeWayOp
963+
//===----------------------------------------------------------------------===//
964+
965+
def CIR_CmpThreeWayOp : CIR_Op<"cmp3way", [Pure, SameTypeOperands]> {
966+
let summary = "Compare two values with C++ three-way comparison semantics";
967+
let description = [{
968+
The `cir.cmp3way` operation models the `<=>` operator in C++20. It takes two
969+
operands with the same type and produces a result indicating the ordering
970+
between the two input operands.
971+
972+
The result of the operation is a signed integer that indicates the ordering
973+
between the two input operands.
974+
975+
There are two kinds of ordering: strong ordering and partial ordering.
976+
Comparing different types of values yields different kinds of orderings.
977+
The `info` parameter gives the ordering kind and other necessary information
978+
about the comparison.
979+
980+
Example:
981+
982+
```mlir
983+
!s32i = !cir.int<s, 32>
984+
985+
#cmp3way_strong = #cmp3way_info<strong, lt = -1, eq = 0, gt = 1>
986+
#cmp3way_partial = #cmp3way_info<strong, lt = -1, eq = 0, gt = 1, unordered = 2>
987+
988+
%0 = cir.const #cir.int<0> : !s32i
989+
%1 = cir.const #cir.int<1> : !s32i
990+
%2 = cir.cmp3way(%0 : !s32i, %1, #cmp3way_strong) : !s8i
991+
992+
%3 = cir.const #cir.fp<0.0> : !cir.float
993+
%4 = cir.const #cir.fp<1.0> : !cir.float
994+
%5 = cir.cmp3way(%3 : !cir.float, %4, #cmp3way_partial) : !s8i
995+
```
996+
}];
997+
998+
let arguments = (ins
999+
CIR_AnyType:$lhs,
1000+
CIR_AnyType:$rhs,
1001+
CIR_CmpThreeWayInfoAttr:$info
1002+
);
1003+
1004+
let results = (outs CIR_AnySIntType:$result);
1005+
1006+
let assemblyFormat = [{
1007+
`(` $lhs `:` type($lhs) `,` $rhs `,` qualified($info) `)`
1008+
`:` type($result) attr-dict
1009+
}];
1010+
1011+
let extraClassDeclaration = [{
1012+
/// Determine whether this three-way comparison produces a strong ordering.
1013+
bool isStrongOrdering() {
1014+
return getInfo().getOrdering() == cir::CmpOrdering::Strong;
1015+
}
1016+
1017+
/// Determine whether this three-way comparison compares integral operands.
1018+
bool isIntegralComparison() {
1019+
return mlir::isa<cir::IntType>(getLhs().getType());
1020+
}
1021+
}];
1022+
1023+
let hasVerifier = 1;
1024+
}
1025+
9611026
//===----------------------------------------------------------------------===//
9621027
// SwitchOp
9631028
//===----------------------------------------------------------------------===//

clang/lib/CIR/CodeGen/CIRGenBuilder.h

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -567,6 +567,37 @@ class CIRGenBuilderTy : public cir::CIRBaseBuilderTy {
567567
return cir::StackRestoreOp::create(*this, loc, v);
568568
}
569569

570+
cir::CmpThreeWayOp createThreeWayCmpStrong(mlir::Location loc,
571+
mlir::Value lhs, mlir::Value rhs,
572+
const llvm::APSInt &ltRes,
573+
const llvm::APSInt &eqRes,
574+
const llvm::APSInt &gtRes) {
575+
assert(ltRes.getBitWidth() == eqRes.getBitWidth() &&
576+
ltRes.getBitWidth() == gtRes.getBitWidth() &&
577+
"the three comparison results must have the same bit width");
578+
cir::IntType cmpResultTy = getSIntNTy(ltRes.getBitWidth());
579+
auto infoAttr = cir::CmpThreeWayInfoAttr::get(
580+
getContext(), ltRes.getSExtValue(), eqRes.getSExtValue(), gtRes.getSExtValue());
581+
return cir::CmpThreeWayOp::create(*this, loc, cmpResultTy, lhs, rhs,
582+
infoAttr);
583+
}
584+
585+
cir::CmpThreeWayOp
586+
createThreeWayCmpPartial(mlir::Location loc, mlir::Value lhs, mlir::Value rhs,
587+
const llvm::APSInt &ltRes, const llvm::APSInt &eqRes,
588+
const llvm::APSInt &gtRes,
589+
const llvm::APSInt &unorderedRes) {
590+
assert(ltRes.getBitWidth() == eqRes.getBitWidth() &&
591+
ltRes.getBitWidth() == gtRes.getBitWidth() &&
592+
ltRes.getBitWidth() == unorderedRes.getBitWidth() &&
593+
"the four comparison results must have the same bit width");
594+
auto cmpResultTy = getSIntNTy(ltRes.getBitWidth());
595+
auto infoAttr = cir::CmpThreeWayInfoAttr::get(
596+
getContext(), ltRes.getSExtValue(), eqRes.getSExtValue(), gtRes.getSExtValue(), unorderedRes.getSExtValue());
597+
return cir::CmpThreeWayOp::create(*this, loc, cmpResultTy, lhs, rhs,
598+
infoAttr);
599+
}
600+
570601
mlir::Value createSetBitfield(mlir::Location loc, mlir::Type resultType,
571602
Address dstAddr, mlir::Type storageType,
572603
mlir::Value src, const CIRGenBitFieldInfo &info,

clang/lib/CIR/CodeGen/CIRGenExprAggregate.cpp

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "clang/AST/Expr.h"
1919
#include "clang/AST/RecordLayout.h"
2020
#include "clang/AST/StmtVisitor.h"
21+
#include "llvm/IR/Value.h"
2122
#include <cstdint>
2223

2324
using namespace clang;
@@ -298,8 +299,63 @@ class AggExprEmitter : public StmtVisitor<AggExprEmitter> {
298299
Visit(e->getRHS());
299300
}
300301
void VisitBinCmp(const BinaryOperator *e) {
301-
cgf.cgm.errorNYI(e->getSourceRange(), "AggExprEmitter: VisitBinCmp");
302+
assert(cgf.getContext().hasSameType(e->getLHS()->getType(),
303+
e->getRHS()->getType()));
304+
const ComparisonCategoryInfo &cmpInfo =
305+
cgf.getContext().CompCategories.getInfoForType(e->getType());
306+
assert(cmpInfo.Record->isTriviallyCopyable() &&
307+
"cannot copy non-trivially copyable aggregate");
308+
309+
QualType argTy = e->getLHS()->getType();
310+
311+
if (!argTy->isIntegralOrEnumerationType() && !argTy->isRealFloatingType() &&
312+
!argTy->isNullPtrType() && !argTy->isPointerType() &&
313+
!argTy->isMemberPointerType() && !argTy->isAnyComplexType())
314+
llvm_unreachable("aggregate three-way comparison");
315+
316+
mlir::Location loc = cgf.getLoc(e->getSourceRange());
317+
CIRGenBuilderTy builder = cgf.getBuilder();
318+
319+
if (e->getType()->isAnyComplexType())
320+
llvm_unreachable("NYI");
321+
322+
mlir::Value lhs = cgf.emitAnyExpr(e->getLHS()).getValue();
323+
mlir::Value rhs = cgf.emitAnyExpr(e->getRHS()).getValue();
324+
325+
mlir::Value resultScalar;
326+
if (argTy->isNullPtrType()) {
327+
resultScalar =
328+
builder.getConstInt(loc, cmpInfo.getEqualOrEquiv()->getIntValue());
329+
} else {
330+
llvm::APSInt ltRes = cmpInfo.getLess()->getIntValue();
331+
llvm::APSInt eqRes = cmpInfo.getEqualOrEquiv()->getIntValue();
332+
llvm::APSInt gtRes = cmpInfo.getGreater()->getIntValue();
333+
if (!cmpInfo.isPartial()) {
334+
// Strong ordering.
335+
resultScalar = builder.createThreeWayCmpStrong(loc, lhs, rhs, ltRes,
336+
eqRes, gtRes);
337+
} else {
338+
// Partial ordering.
339+
llvm::APSInt unorderedRes = cmpInfo.getUnordered()->getIntValue();
340+
resultScalar = builder.createThreeWayCmpPartial(
341+
loc, lhs, rhs, ltRes, eqRes, gtRes, unorderedRes);
342+
}
343+
}
344+
345+
// Create the return value in the destination slot.
346+
ensureDest(loc, e->getType());
347+
LValue destLVal = cgf.makeAddrLValue(dest.getAddress(), e->getType());
348+
349+
// Emit the address of the first (and only) field in the comparison category
350+
// type, and initialize it from the constant integer value produced above.
351+
const FieldDecl *resultField = *cmpInfo.Record->field_begin();
352+
LValue fieldLVal = cgf.emitLValueForFieldInitialization(destLVal, resultField,
353+
resultField->getName());
354+
cgf.emitStoreThroughLValue(RValue::get(resultScalar), fieldLVal);
355+
356+
// All done! The result is in the dest slot.
302357
}
358+
303359
void VisitCXXRewrittenBinaryOperator(CXXRewrittenBinaryOperator *e) {
304360
cgf.cgm.errorNYI(e->getSourceRange(),
305361
"AggExprEmitter: VisitCXXRewrittenBinaryOperator");

clang/lib/CIR/Dialect/IR/CIRAttrs.cpp

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,60 @@ LogicalResult FPAttr::verify(function_ref<InFlightDiagnostic()> emitError,
267267
return success();
268268
}
269269

270+
271+
//===----------------------------------------------------------------------===//
272+
// CmpThreeWayInfoAttr definitions
273+
//===----------------------------------------------------------------------===//
274+
275+
std::string CmpThreeWayInfoAttr::getAlias() const {
276+
std::string alias = "cmp3way_info";
277+
278+
if (getOrdering() == CmpOrdering::Strong)
279+
alias.append("_strong_");
280+
else
281+
alias.append("_partial_");
282+
283+
auto appendInt = [&](int64_t value) {
284+
if (value < 0) {
285+
alias.push_back('n');
286+
value = -value;
287+
}
288+
alias.append(std::to_string(value));
289+
};
290+
291+
alias.append("lt");
292+
appendInt(getLt());
293+
alias.append("eq");
294+
appendInt(getEq());
295+
alias.append("gt");
296+
appendInt(getGt());
297+
298+
if (std::optional<int> unordered = getUnordered()) {
299+
alias.append("un");
300+
appendInt(unordered.value());
301+
}
302+
303+
return alias;
304+
}
305+
306+
LogicalResult
307+
CmpThreeWayInfoAttr::verify(function_ref<InFlightDiagnostic()> emitError,
308+
CmpOrdering ordering, int64_t lt, int64_t eq,
309+
int64_t gt, std::optional<int64_t> unordered) {
310+
// The presense of unordered must match the value of ordering.
311+
if (ordering == CmpOrdering::Strong && unordered) {
312+
emitError() << "strong ordering does not include unordered ordering";
313+
return failure();
314+
}
315+
if (ordering == CmpOrdering::Partial && !unordered) {
316+
emitError() << "partial ordering lacks unordered ordering";
317+
return failure();
318+
}
319+
320+
return success();
321+
}
322+
323+
270324
//===----------------------------------------------------------------------===//
271325
// ConstComplexAttr definitions
272326
//===----------------------------------------------------------------------===//

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,11 @@ struct CIROpAsmDialectInterface : public OpAsmDialectInterface {
7878
os << dynCastInfoAttr.getAlias();
7979
return AliasResult::FinalAlias;
8080
}
81+
if (auto cmpThreeWayInfoAttr =
82+
mlir::dyn_cast<cir::CmpThreeWayInfoAttr>(attr)) {
83+
os << cmpThreeWayInfoAttr.getAlias();
84+
return AliasResult::FinalAlias;
85+
}
8186
return AliasResult::NoAlias;
8287
}
8388
};
@@ -1132,6 +1137,21 @@ Block *cir::BrCondOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
11321137
return nullptr;
11331138
}
11341139

1140+
1141+
//===----------------------------------------------------------------------===//
1142+
// CmpThreeWayOp
1143+
//===----------------------------------------------------------------------===//
1144+
1145+
mlir::LogicalResult CmpThreeWayOp::verify() {
1146+
// Type of the result must be a signed integer type.
1147+
if (!getType().isSigned()) {
1148+
emitOpError() << "result type of cir.cmp3way must be a signed integer type";
1149+
return failure();
1150+
}
1151+
1152+
return success();
1153+
}
1154+
11351155
//===----------------------------------------------------------------------===//
11361156
// CaseOp
11371157
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)