diff --git a/llvm/include/llvm/CodeGen/ValueTypes.td b/llvm/include/llvm/CodeGen/ValueTypes.td index 44edec98d20f3..9ea127dd15943 100644 --- a/llvm/include/llvm/CodeGen/ValueTypes.td +++ b/llvm/include/llvm/CodeGen/ValueTypes.td @@ -367,6 +367,10 @@ def aarch64mfp8 : ValueType<8, 253>; // 8-bit value in FPR (AArch64) def c64 : VTCheriCapability<64, 254>; // 64-bit CHERI capability value def c128 : VTCheriCapability<128, 255>; // 128-bit CHERI capability value +// Pseudo valuetype mapped to the current CHERI capability pointer size. +// Should only be used in TableGen. +def cPTR : VTAny<503>; + let isNormalValueType = false in { def token : ValueType<0, 504>; // TokenTy def MetadataVT : ValueType<0, 505> { // Metadata diff --git a/llvm/include/llvm/CodeGenTypes/MachineValueType.h b/llvm/include/llvm/CodeGenTypes/MachineValueType.h index 321fb6b601868..5c18c59ed592d 100644 --- a/llvm/include/llvm/CodeGenTypes/MachineValueType.h +++ b/llvm/include/llvm/CodeGenTypes/MachineValueType.h @@ -582,6 +582,12 @@ namespace llvm { MVT::LAST_FP_SCALABLE_VECTOR_VALUETYPE, force_iteration_on_noniterable_enum); } + + static auto cheri_capability_valuetypes() { + return enum_seq_inclusive(MVT::FIRST_CHERI_CAPABILITY_VALUETYPE, + MVT::LAST_CHERI_CAPABILITY_VALUETYPE, + force_iteration_on_noniterable_enum); + } /// @} }; diff --git a/llvm/test/TableGen/CPtrWildcard.td b/llvm/test/TableGen/CPtrWildcard.td new file mode 100644 index 0000000000000..96b51ae1044a3 --- /dev/null +++ b/llvm/test/TableGen/CPtrWildcard.td @@ -0,0 +1,74 @@ +// RUN: llvm-tblgen -gen-dag-isel -I %p/../../include %s -o - | FileCheck %s + +// Create an intrinsic that uses cPTR to overload on capability pointer types, +// and verify that we can match it correct in SelectionDAG. + +// CHECK: static const unsigned char MatcherTable[] = { +// CHECK-NEXT: /* 0*/ OPC_CheckOpcode, TARGET_VAL(ISD::INTRINSIC_WO_CHAIN), +// CHECK-NEXT:/* 3*/ OPC_CheckChild0Integer, 42, +// CHECK-NEXT:/* 5*/ OPC_RecordChild1, // #0 = $src +// CHECK-NEXT:/* 6*/ OPC_Scope, 9, /*->17*/ // 2 children in Scope +// CHECK-NEXT:/* 8*/ OPC_CheckChild1Type, /*MVT::c64*/126|128,1/*254*/, +// CHECK-NEXT:/* 11*/ OPC_MorphNodeTo1None, TARGET_VAL(MyTarget::C64_TO_I64), +// CHECK-NEXT: /*MVT::i64*/8, 1/*#Ops*/, 0, +// CHECK-NEXT: // Src: (intrinsic_wo_chain:{ *:[i64] } 21:{ *:[iPTR] }, c64:{ *:[c64] }:$src) - Complexity = 8 +// CHECK-NEXT: // Dst: (C64_TO_I64:{ *:[i64] } ?:{ *:[c64] }:$src) +// CHECK-NEXT:/* 17*/ /*Scope*/ 9, /*->27*/ +// CHECK-NEXT:/* 18*/ OPC_CheckChild1Type, /*MVT::c128*/127|128,1/*255*/, +// CHECK-NEXT:/* 21*/ OPC_MorphNodeTo1None, TARGET_VAL(MyTarget::C128_TO_I64), +// CHECK-NEXT: /*MVT::i64*/8, 1/*#Ops*/, 0, +// CHECK-NEXT: // Src: (intrinsic_wo_chain:{ *:[i64] } 21:{ *:[iPTR] }, c128:{ *:[c128] }:$src) - Complexity = 8 +// CHECK-NEXT: // Dst: (C128_TO_I64:{ *:[i64] } ?:{ *:[c128] }:$src) +// CHECK-NEXT:/* 27*/ 0, /*End of Scope*/ +// CHECK-NEXT: 0 +// CHECK-NEXT: }; // Total Array size is 29 bytes + +include "llvm/Target/Target.td" + +def my_cap_ty : LLVMQualPointerType<200> { + let VT = cPTR; +} + +def int_cap_get_length : + Intrinsic<[llvm_i64_ty], + [my_cap_ty], + [IntrNoMem, IntrWillReturn]>; + +class CapReg : Register { + let Namespace = "MyTarget"; +} + +def C64 : CapReg<"c0">; +def C64s + : RegisterClass<"MyTarget", [i64, c64], 64, + (add C64)>; + +def C128 : CapReg<"c0">; +def C128s + : RegisterClass<"MyTarget", [c128], 64, + (add C128)>; + +def C64_TO_I64 : Instruction { + let Namespace = "MyTarget"; + let OutOperandList = (outs C64s:$dst); + let InOperandList = (ins C64s:$src); +} + +def C128_TO_I64 : Instruction { + let Namespace = "MyTarget"; + let OutOperandList = (outs C64s:$dst); + let InOperandList = (ins C128s:$src); +} + +def : Pat< + (int_cap_get_length c64:$src), + (C64_TO_I64 $src) +>; + +def : Pat< + (int_cap_get_length c128:$src), + (C128_TO_I64 $src) +>; + +def MyTargetISA : InstrInfo; +def MyTarget : Target { let InstructionSet = MyTargetISA; } diff --git a/llvm/utils/TableGen/Common/CodeGenDAGPatterns.cpp b/llvm/utils/TableGen/Common/CodeGenDAGPatterns.cpp index f1f7cd72ef9f2..5a8d3aa1036ba 100644 --- a/llvm/utils/TableGen/Common/CodeGenDAGPatterns.cpp +++ b/llvm/utils/TableGen/Common/CodeGenDAGPatterns.cpp @@ -335,6 +335,8 @@ bool TypeSetByHwMode::intersect(SetType &Out, const SetType &In) { using WildPartT = std::pair>; static const WildPartT WildParts[] = { {MVT::iPTR, [](MVT T) { return T.isScalarInteger() || T == MVT::iPTR; }}, + {MVT::cPTR, + [](MVT T) { return T.isCheriCapability() || T == MVT::cPTR; }}, }; bool Changed = false; @@ -816,6 +818,10 @@ void TypeInfer::expandOverloads(TypeSetByHwMode::SetType &Out, if (Out.count(MVT::pAny)) { Out.erase(MVT::pAny); Out.insert(MVT::iPTR); + for (MVT T : MVT::cheri_capability_valuetypes()) { + if (Legal.count(T)) + Out.insert(MVT::cPTR); + } } else if (Out.count(MVT::iAny)) { Out.erase(MVT::iAny); for (MVT T : MVT::integer_valuetypes()) @@ -1647,9 +1653,11 @@ bool SDTypeConstraint::ApplyTypeConstraint(TreePatternNode &N, case SDTCisVT: // Operand must be a particular type. return NodeToApply.UpdateNodeType(ResNo, VVT, TP); - case SDTCisPtrTy: - // Operand must be same as target pointer type. - return NodeToApply.UpdateNodeType(ResNo, MVT::iPTR, TP); + case SDTCisPtrTy: { + // Operand must be a legal pointer (iPTR, or possibly cPTR) type. + const TypeSetByHwMode &PtrTys = TP.getDAGPatterns().getLegalPtrTypes(); + return NodeToApply.UpdateNodeType(ResNo, PtrTys, TP); + } case SDTCisInt: // Require it to be one of the legal integer VTs. return TI.EnforceInteger(NodeToApply.getExtType(ResNo)); @@ -3260,6 +3268,7 @@ CodeGenDAGPatterns::CodeGenDAGPatterns(const RecordKeeper &R, PatternRewriterFn PatternRewriter) : Records(R), Target(R), Intrinsics(R), LegalVTS(Target.getLegalValueTypes()), + LegalPtrVTS(ComputeLegalPtrTypes()), PatternRewriter(std::move(PatternRewriter)) { ParseNodeInfo(); ParseNodeTransforms(); @@ -3295,6 +3304,36 @@ const Record *CodeGenDAGPatterns::getSDNodeNamed(StringRef Name) const { return N; } +// Compute the subset of iPTR and cPTR legal for each mode, coalescing into the +// default mode where possible to avoid predicate explosion. +TypeSetByHwMode CodeGenDAGPatterns::ComputeLegalPtrTypes() const { + auto LegalPtrsForSet = [](const MachineValueTypeSet &In) { + MachineValueTypeSet Out; + Out.insert(MVT::iPTR); + for (MVT T : MVT::cheri_capability_valuetypes()) { + if (In.count(T)) { + Out.insert(MVT::cPTR); + break; + } + } + return Out; + }; + + const TypeSetByHwMode &LegalTypes = getLegalTypes(); + MachineValueTypeSet LegalPtrsDefault = + LegalPtrsForSet(LegalTypes.get(DefaultMode)); + + TypeSetByHwMode LegalPtrTypes; + for (const auto &I : LegalTypes) { + MachineValueTypeSet S = LegalPtrsForSet(I.second); + if (I.first != DefaultMode && S == LegalPtrsDefault) + continue; + LegalPtrTypes.getOrCreate(I.first).insert(S); + } + + return LegalPtrTypes; +} + // Parse all of the SDNode definitions for the target, populating SDNodes. void CodeGenDAGPatterns::ParseNodeInfo() { const CodeGenHwModes &CGH = getTargetInfo().getHwModes(); diff --git a/llvm/utils/TableGen/Common/CodeGenDAGPatterns.h b/llvm/utils/TableGen/Common/CodeGenDAGPatterns.h index 64fec275faa68..2ed8d1376b045 100644 --- a/llvm/utils/TableGen/Common/CodeGenDAGPatterns.h +++ b/llvm/utils/TableGen/Common/CodeGenDAGPatterns.h @@ -1135,6 +1135,7 @@ class CodeGenDAGPatterns { std::vector PatternsToMatch; TypeSetByHwMode LegalVTS; + TypeSetByHwMode LegalPtrVTS; using PatternRewriterFn = std::function; PatternRewriterFn PatternRewriter; @@ -1148,6 +1149,7 @@ class CodeGenDAGPatterns { CodeGenTarget &getTargetInfo() { return Target; } const CodeGenTarget &getTargetInfo() const { return Target; } const TypeSetByHwMode &getLegalTypes() const { return LegalVTS; } + const TypeSetByHwMode &getLegalPtrTypes() const { return LegalPtrVTS; } const Record *getSDNodeNamed(StringRef Name) const; @@ -1249,6 +1251,7 @@ class CodeGenDAGPatterns { } private: + TypeSetByHwMode ComputeLegalPtrTypes() const; void ParseNodeInfo(); void ParseNodeTransforms(); void ParseComplexPatterns(); diff --git a/llvm/utils/TableGen/Common/DAGISelMatcher.cpp b/llvm/utils/TableGen/Common/DAGISelMatcher.cpp index 255974624e8f0..4fdb386bf45e7 100644 --- a/llvm/utils/TableGen/Common/DAGISelMatcher.cpp +++ b/llvm/utils/TableGen/Common/DAGISelMatcher.cpp @@ -328,6 +328,14 @@ static bool TypesAreContradictory(MVT::SimpleValueType T1, if (T1 == T2) return false; + if (T1 == MVT::pAny) + return TypesAreContradictory(MVT::iPTR, T2) && + TypesAreContradictory(MVT::cPTR, T2); + + if (T2 == MVT::pAny) + return TypesAreContradictory(T1, MVT::iPTR) && + TypesAreContradictory(T1, MVT::cPTR); + // If either type is about iPtr, then they don't conflict unless the other // one is not a scalar integer type. if (T1 == MVT::iPTR) @@ -336,7 +344,13 @@ static bool TypesAreContradictory(MVT::SimpleValueType T1, if (T2 == MVT::iPTR) return !MVT(T1).isInteger() || MVT(T1).isVector(); - // Otherwise, they are two different non-iPTR types, they conflict. + if (T1 == MVT::cPTR) + return !MVT(T2).isCheriCapability() || MVT(T2).isVector(); + + if (T2 == MVT::cPTR) + return !MVT(T1).isCheriCapability() || MVT(T1).isVector(); + + // Otherwise, they are two different non-iPTR/cPTR types, they conflict. return true; } diff --git a/llvm/utils/TableGen/Common/GlobalISel/GlobalISelMatchTable.cpp b/llvm/utils/TableGen/Common/GlobalISel/GlobalISelMatchTable.cpp index 8c8d5d77ebd73..746726cf2510e 100644 --- a/llvm/utils/TableGen/Common/GlobalISel/GlobalISelMatchTable.cpp +++ b/llvm/utils/TableGen/Common/GlobalISel/GlobalISelMatchTable.cpp @@ -1426,7 +1426,9 @@ Error OperandMatcher::addTypeCheckPredicate(const TypeSetByHwMode &VTy, if (!VTy.isMachineValueType()) return failUnsupported("unsupported typeset"); - if (VTy.getMachineValueType() == MVT::iPTR && OperandIsAPointer) { + if ((VTy.getMachineValueType() == MVT::iPTR || + VTy.getMachineValueType() == MVT::cPTR) && + OperandIsAPointer) { addPredicate(0); return Error::success(); } diff --git a/llvm/utils/TableGen/DAGISelMatcherOpt.cpp b/llvm/utils/TableGen/DAGISelMatcherOpt.cpp index 8d8189983270e..268e6bbc4eee3 100644 --- a/llvm/utils/TableGen/DAGISelMatcherOpt.cpp +++ b/llvm/utils/TableGen/DAGISelMatcherOpt.cpp @@ -519,9 +519,9 @@ static void FactorScope(std::unique_ptr &MatcherPtr) { CheckTypeMatcher *CTM = cast_or_null( FindNodeWithKind(Optn, Matcher::CheckType)); if (!CTM || - // iPTR checks could alias any other case without us knowing, don't - // bother with them. - CTM->getType() == MVT::iPTR || + // iPTR/cPTR checks could alias any other case without us knowing, + // don't bother with them. + CTM->getType() == MVT::iPTR || CTM->getType() == MVT::cPTR || // SwitchType only works for result #0. CTM->getResNo() != 0 || // If the CheckType isn't at the start of the list, see if we can move