-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[CIR] Upstream ThreeWayCmpOp #169963
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[CIR] Upstream ThreeWayCmpOp #169963
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -958,6 +958,71 @@ def CIR_ScopeOp : CIR_Op<"scope", [ | |||||||||||||||||
| let hasLLVMLowering = false; | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
| //===----------------------------------------------------------------------===// | ||||||||||||||||||
| // CmpThreeWayOp | ||||||||||||||||||
| //===----------------------------------------------------------------------===// | ||||||||||||||||||
|
|
||||||||||||||||||
| def CIR_CmpThreeWayOp : CIR_Op<"cmp3way", [Pure, SameTypeOperands]> { | ||||||||||||||||||
| let summary = "Compare two values with C++ three-way comparison semantics"; | ||||||||||||||||||
| let description = [{ | ||||||||||||||||||
| The `cir.cmp3way` operation models the `<=>` operator in C++20. It takes two | ||||||||||||||||||
| operands with the same type and produces a result indicating the ordering | ||||||||||||||||||
| between the two input operands. | ||||||||||||||||||
|
|
||||||||||||||||||
| The result of the operation is a signed integer that indicates the ordering | ||||||||||||||||||
| between the two input operands. | ||||||||||||||||||
|
|
||||||||||||||||||
| There are two kinds of ordering: strong ordering and partial ordering. | ||||||||||||||||||
| Comparing different types of values yields different kinds of orderings. | ||||||||||||||||||
| The `info` parameter gives the ordering kind and other necessary information | ||||||||||||||||||
| about the comparison. | ||||||||||||||||||
|
|
||||||||||||||||||
| Example: | ||||||||||||||||||
|
|
||||||||||||||||||
| ```mlir | ||||||||||||||||||
| !s32i = !cir.int<s, 32> | ||||||||||||||||||
|
|
||||||||||||||||||
| #cmp3way_strong = #cmp3way_info<strong, lt = -1, eq = 0, gt = 1> | ||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||
| #cmp3way_partial = #cmp3way_info<strong, lt = -1, eq = 0, gt = 1, unordered = 2> | ||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||
|
|
||||||||||||||||||
| %0 = cir.const #cir.int<0> : !s32i | ||||||||||||||||||
| %1 = cir.const #cir.int<1> : !s32i | ||||||||||||||||||
| %2 = cir.cmp3way(%0 : !s32i, %1, #cmp3way_strong) : !s8i | ||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||
|
|
||||||||||||||||||
| %3 = cir.const #cir.fp<0.0> : !cir.float | ||||||||||||||||||
| %4 = cir.const #cir.fp<1.0> : !cir.float | ||||||||||||||||||
| %5 = cir.cmp3way(%3 : !cir.float, %4, #cmp3way_partial) : !s8i | ||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||
| ``` | ||||||||||||||||||
| }]; | ||||||||||||||||||
|
|
||||||||||||||||||
| let arguments = (ins | ||||||||||||||||||
| CIR_AnyType:$lhs, | ||||||||||||||||||
| CIR_AnyType:$rhs, | ||||||||||||||||||
| CIR_CmpThreeWayInfoAttr:$info | ||||||||||||||||||
| ); | ||||||||||||||||||
|
|
||||||||||||||||||
| let results = (outs CIR_AnySIntType:$result); | ||||||||||||||||||
|
|
||||||||||||||||||
| let assemblyFormat = [{ | ||||||||||||||||||
| `(` $lhs `:` type($lhs) `,` $rhs `,` qualified($info) `)` | ||||||||||||||||||
| `:` type($result) attr-dict | ||||||||||||||||||
| }]; | ||||||||||||||||||
|
Comment on lines
+1006
to
+1009
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
I don't think this was consistent with our latest ASM style guidelines. I'm not 100% sure about my suggestion. @xlauko Can you help here? |
||||||||||||||||||
|
|
||||||||||||||||||
| let extraClassDeclaration = [{ | ||||||||||||||||||
| /// Determine whether this three-way comparison produces a strong ordering. | ||||||||||||||||||
| bool isStrongOrdering() { | ||||||||||||||||||
| return getInfo().getOrdering() == cir::CmpOrdering::Strong; | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
| /// Determine whether this three-way comparison compares integral operands. | ||||||||||||||||||
| bool isIntegralComparison() { | ||||||||||||||||||
| return mlir::isa<cir::IntType>(getLhs().getType()); | ||||||||||||||||||
| } | ||||||||||||||||||
| }]; | ||||||||||||||||||
|
|
||||||||||||||||||
| let hasVerifier = 1; | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
| //===----------------------------------------------------------------------===// | ||||||||||||||||||
| // SwitchOp | ||||||||||||||||||
| //===----------------------------------------------------------------------===// | ||||||||||||||||||
|
|
||||||||||||||||||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -567,6 +567,39 @@ class CIRGenBuilderTy : public cir::CIRBaseBuilderTy { | |||||
| return cir::StackRestoreOp::create(*this, loc, v); | ||||||
| } | ||||||
|
|
||||||
| cir::CmpThreeWayOp createThreeWayCmpStrong(mlir::Location loc, | ||||||
| mlir::Value lhs, mlir::Value rhs, | ||||||
| const llvm::APSInt <Res, | ||||||
| const llvm::APSInt &eqRes, | ||||||
| const llvm::APSInt >Res) { | ||||||
| assert(ltRes.getBitWidth() == eqRes.getBitWidth() && | ||||||
| ltRes.getBitWidth() == gtRes.getBitWidth() && | ||||||
| "the three comparison results must have the same bit width"); | ||||||
| cir::IntType cmpResultTy = getSIntNTy(ltRes.getBitWidth()); | ||||||
| auto infoAttr = cir::CmpThreeWayInfoAttr::get( | ||||||
| getContext(), ltRes.getSExtValue(), eqRes.getSExtValue(), | ||||||
| gtRes.getSExtValue()); | ||||||
| return cir::CmpThreeWayOp::create(*this, loc, cmpResultTy, lhs, rhs, | ||||||
| infoAttr); | ||||||
| } | ||||||
|
|
||||||
| cir::CmpThreeWayOp | ||||||
| createThreeWayCmpPartial(mlir::Location loc, mlir::Value lhs, mlir::Value rhs, | ||||||
| const llvm::APSInt <Res, const llvm::APSInt &eqRes, | ||||||
| const llvm::APSInt >Res, | ||||||
| const llvm::APSInt &unorderedRes) { | ||||||
| assert(ltRes.getBitWidth() == eqRes.getBitWidth() && | ||||||
| ltRes.getBitWidth() == gtRes.getBitWidth() && | ||||||
| ltRes.getBitWidth() == unorderedRes.getBitWidth() && | ||||||
| "the four comparison results must have the same bit width"); | ||||||
| auto cmpResultTy = getSIntNTy(ltRes.getBitWidth()); | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
| auto infoAttr = cir::CmpThreeWayInfoAttr::get( | ||||||
| getContext(), ltRes.getSExtValue(), eqRes.getSExtValue(), | ||||||
| gtRes.getSExtValue(), unorderedRes.getSExtValue()); | ||||||
| return cir::CmpThreeWayOp::create(*this, loc, cmpResultTy, lhs, rhs, | ||||||
| infoAttr); | ||||||
| } | ||||||
|
|
||||||
| mlir::Value createSetBitfield(mlir::Location loc, mlir::Type resultType, | ||||||
| Address dstAddr, mlir::Type storageType, | ||||||
| mlir::Value src, const CIRGenBitFieldInfo &info, | ||||||
|
|
||||||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -18,6 +18,7 @@ | |||||
| #include "clang/AST/Expr.h" | ||||||
| #include "clang/AST/RecordLayout.h" | ||||||
| #include "clang/AST/StmtVisitor.h" | ||||||
| #include "llvm/IR/Value.h" | ||||||
| #include <cstdint> | ||||||
|
|
||||||
| using namespace clang; | ||||||
|
|
@@ -298,8 +299,63 @@ class AggExprEmitter : public StmtVisitor<AggExprEmitter> { | |||||
| Visit(e->getRHS()); | ||||||
| } | ||||||
| void VisitBinCmp(const BinaryOperator *e) { | ||||||
| cgf.cgm.errorNYI(e->getSourceRange(), "AggExprEmitter: VisitBinCmp"); | ||||||
| assert(cgf.getContext().hasSameType(e->getLHS()->getType(), | ||||||
| e->getRHS()->getType())); | ||||||
| const ComparisonCategoryInfo &cmpInfo = | ||||||
| cgf.getContext().CompCategories.getInfoForType(e->getType()); | ||||||
| assert(cmpInfo.Record->isTriviallyCopyable() && | ||||||
| "cannot copy non-trivially copyable aggregate"); | ||||||
|
|
||||||
| QualType argTy = e->getLHS()->getType(); | ||||||
|
|
||||||
| if (!argTy->isIntegralOrEnumerationType() && !argTy->isRealFloatingType() && | ||||||
| !argTy->isNullPtrType() && !argTy->isPointerType() && | ||||||
| !argTy->isMemberPointerType() && !argTy->isAnyComplexType()) | ||||||
| llvm_unreachable("aggregate three-way comparison"); | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
|
|
||||||
| mlir::Location loc = cgf.getLoc(e->getSourceRange()); | ||||||
| CIRGenBuilderTy builder = cgf.getBuilder(); | ||||||
|
|
||||||
| if (e->getType()->isAnyComplexType()) | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should there also be a diagnostic issued for aggregate types here? I'm not sure we handle that case correctly. The unreachable on line 314 doesn't seem to cover it. |
||||||
| llvm_unreachable("NYI"); | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||
|
|
||||||
| mlir::Value lhs = cgf.emitAnyExpr(e->getLHS()).getValue(); | ||||||
| mlir::Value rhs = cgf.emitAnyExpr(e->getRHS()).getValue(); | ||||||
|
|
||||||
| mlir::Value resultScalar; | ||||||
| if (argTy->isNullPtrType()) { | ||||||
| resultScalar = | ||||||
| builder.getConstInt(loc, cmpInfo.getEqualOrEquiv()->getIntValue()); | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think this case is covered in your test. |
||||||
| } else { | ||||||
| llvm::APSInt ltRes = cmpInfo.getLess()->getIntValue(); | ||||||
| llvm::APSInt eqRes = cmpInfo.getEqualOrEquiv()->getIntValue(); | ||||||
| llvm::APSInt gtRes = cmpInfo.getGreater()->getIntValue(); | ||||||
| if (!cmpInfo.isPartial()) { | ||||||
| // Strong ordering. | ||||||
| resultScalar = | ||||||
| builder.createThreeWayCmpStrong(loc, lhs, rhs, ltRes, eqRes, gtRes); | ||||||
| } else { | ||||||
| // Partial ordering. | ||||||
| llvm::APSInt unorderedRes = cmpInfo.getUnordered()->getIntValue(); | ||||||
| resultScalar = builder.createThreeWayCmpPartial( | ||||||
| loc, lhs, rhs, ltRes, eqRes, gtRes, unorderedRes); | ||||||
| } | ||||||
| } | ||||||
|
|
||||||
| // Create the return value in the destination slot. | ||||||
| ensureDest(loc, e->getType()); | ||||||
| LValue destLVal = cgf.makeAddrLValue(dest.getAddress(), e->getType()); | ||||||
|
|
||||||
| // Emit the address of the first (and only) field in the comparison category | ||||||
| // type, and initialize it from the constant integer value produced above. | ||||||
| const FieldDecl *resultField = *cmpInfo.Record->field_begin(); | ||||||
| LValue fieldLVal = cgf.emitLValueForFieldInitialization( | ||||||
| destLVal, resultField, resultField->getName()); | ||||||
| cgf.emitStoreThroughLValue(RValue::get(resultScalar), fieldLVal); | ||||||
|
|
||||||
| // All done! The result is in the dest slot. | ||||||
| } | ||||||
|
|
||||||
| void VisitCXXRewrittenBinaryOperator(CXXRewrittenBinaryOperator *e) { | ||||||
| cgf.cgm.errorNYI(e->getSourceRange(), | ||||||
| "AggExprEmitter: VisitCXXRewrittenBinaryOperator"); | ||||||
|
|
||||||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -267,6 +267,58 @@ LogicalResult FPAttr::verify(function_ref<InFlightDiagnostic()> emitError, | |||||
| return success(); | ||||||
| } | ||||||
|
|
||||||
| //===----------------------------------------------------------------------===// | ||||||
| // CmpThreeWayInfoAttr definitions | ||||||
| //===----------------------------------------------------------------------===// | ||||||
|
|
||||||
| std::string CmpThreeWayInfoAttr::getAlias() const { | ||||||
| std::string alias = "cmp3way_info"; | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Maybe this to make it more compact? The examples and tests would need to be updated accordingly. |
||||||
|
|
||||||
| if (getOrdering() == CmpOrdering::Strong) | ||||||
| alias.append("_strong_"); | ||||||
| else | ||||||
| alias.append("_partial_"); | ||||||
|
|
||||||
| auto appendInt = [&](int64_t value) { | ||||||
| if (value < 0) { | ||||||
| alias.push_back('n'); | ||||||
| value = -value; | ||||||
| } | ||||||
| alias.append(std::to_string(value)); | ||||||
| }; | ||||||
|
|
||||||
| alias.append("lt"); | ||||||
| appendInt(getLt()); | ||||||
| alias.append("eq"); | ||||||
| appendInt(getEq()); | ||||||
| alias.append("gt"); | ||||||
| appendInt(getGt()); | ||||||
|
|
||||||
| if (std::optional<int> unordered = getUnordered()) { | ||||||
| alias.append("un"); | ||||||
| appendInt(unordered.value()); | ||||||
| } | ||||||
|
|
||||||
| return alias; | ||||||
| } | ||||||
|
|
||||||
| LogicalResult | ||||||
| CmpThreeWayInfoAttr::verify(function_ref<InFlightDiagnostic()> emitError, | ||||||
| CmpOrdering ordering, int64_t lt, int64_t eq, | ||||||
| int64_t gt, std::optional<int64_t> unordered) { | ||||||
| // The presense of unordered must match the value of ordering. | ||||||
| if (ordering == CmpOrdering::Strong && unordered) { | ||||||
| emitError() << "strong ordering does not include unordered ordering"; | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you add an invalid IR test that verifies these errors? |
||||||
| return failure(); | ||||||
| } | ||||||
| if (ordering == CmpOrdering::Partial && !unordered) { | ||||||
| emitError() << "partial ordering lacks unordered ordering"; | ||||||
| return failure(); | ||||||
| } | ||||||
|
|
||||||
| return success(); | ||||||
| } | ||||||
|
|
||||||
| //===----------------------------------------------------------------------===// | ||||||
| // ConstComplexAttr definitions | ||||||
| //===----------------------------------------------------------------------===// | ||||||
|
|
||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -78,6 +78,11 @@ struct CIROpAsmDialectInterface : public OpAsmDialectInterface { | |
| os << dynCastInfoAttr.getAlias(); | ||
| return AliasResult::FinalAlias; | ||
| } | ||
| if (auto cmpThreeWayInfoAttr = | ||
| mlir::dyn_cast<cir::CmpThreeWayInfoAttr>(attr)) { | ||
| os << cmpThreeWayInfoAttr.getAlias(); | ||
| return AliasResult::FinalAlias; | ||
| } | ||
| return AliasResult::NoAlias; | ||
| } | ||
| }; | ||
|
|
@@ -1132,6 +1137,20 @@ Block *cir::BrCondOp::getSuccessorForOperands(ArrayRef<Attribute> operands) { | |
| return nullptr; | ||
| } | ||
|
|
||
| //===----------------------------------------------------------------------===// | ||
| // CmpThreeWayOp | ||
| //===----------------------------------------------------------------------===// | ||
|
|
||
| mlir::LogicalResult CmpThreeWayOp::verify() { | ||
| // Type of the result must be a signed integer type. | ||
| if (!getType().isSigned()) { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think we need to verify this. The MLIR constraints should handle it automatically. |
||
| emitOpError() << "result type of cir.cmp3way must be a signed integer type"; | ||
| return failure(); | ||
| } | ||
|
|
||
| return success(); | ||
| } | ||
|
|
||
| //===----------------------------------------------------------------------===// | ||
| // CaseOp | ||
| //===----------------------------------------------------------------------===// | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add an example to show the expected format?