Skip to content

Conversation

@jurahul
Copy link
Contributor

@jurahul jurahul commented May 5, 2025

Adopt TrailingObjects convenience API that was added in #138970 in LLVM and MLIR code.

@github-actions
Copy link

github-actions bot commented May 5, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@jurahul jurahul force-pushed the trailing_objects branch 3 times, most recently from 2e68377 to b229281 Compare May 5, 2025 23:14
@jurahul jurahul marked this pull request as ready for review May 7, 2025 13:51
@jurahul jurahul requested a review from nikic May 7, 2025 13:51
@jurahul jurahul requested review from fmayer and jpienaar May 7, 2025 13:51
@llvmbot
Copy link
Member

llvmbot commented May 7, 2025

@llvm/pr-subscribers-tablegen
@llvm/pr-subscribers-llvm-support
@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-debuginfo

Author: Rahul Joshi (jurahul)

Changes

Add a specialization of getTrailingObjects() for a single trailing type. This is a common case and with the specialization you don't need to specify the single trailing type redundantly. Also add an overload for getTrailingObjects which takes size and returns an ArryaRef/MutableArrayRef as that's a common use case as well.

Fix some uses of TrailingObjects class to use private inheritance and remove unneeded numTrailingObjects function when there is a single trailing type.


Patch is 27.70 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/138554.diff

15 Files Affected:

  • (modified) llvm/include/llvm/DebugInfo/BTF/BTF.h (+2-4)
  • (modified) llvm/include/llvm/IR/DataLayout.h (+5-7)
  • (modified) llvm/include/llvm/Support/TrailingObjects.h (+33)
  • (modified) llvm/include/llvm/TableGen/Record.h (+7-13)
  • (modified) llvm/lib/Bitcode/Reader/BitcodeReader.cpp (+1-1)
  • (modified) llvm/lib/IR/AttributeImpl.h (+9-17)
  • (modified) llvm/lib/IR/Attributes.cpp (+2-2)
  • (modified) llvm/lib/Support/TrieRawHashMap.cpp (+2-2)
  • (modified) llvm/lib/Transforms/IPO/LowerTypeTests.cpp (+4-8)
  • (modified) llvm/unittests/Support/TrailingObjectsTest.cpp (+2-4)
  • (modified) mlir/include/mlir/IR/Operation.h (+2-3)
  • (modified) mlir/include/mlir/Tools/PDLL/AST/Nodes.h (+24-28)
  • (modified) mlir/lib/IR/AffineMapDetail.h (+5-3)
  • (modified) mlir/lib/IR/Location.cpp (+6-5)
  • (modified) mlir/lib/IR/TypeDetail.h (+4-5)
diff --git a/llvm/include/llvm/DebugInfo/BTF/BTF.h b/llvm/include/llvm/DebugInfo/BTF/BTF.h
index d88af2ff30bdf..bd666e4202985 100644
--- a/llvm/include/llvm/DebugInfo/BTF/BTF.h
+++ b/llvm/include/llvm/DebugInfo/BTF/BTF.h
@@ -304,14 +304,12 @@ enum PatchableRelocKind : uint32_t {
 // For CommonType sub-types that are followed by a single entry of
 // some type in the binary format.
 #define BTF_DEFINE_TAIL(Type, Accessor)                                        \
-  const Type &Accessor() const { return *getTrailingObjects<Type>(); }
+  const Type &Accessor() const { return *getTrailingObjects(); }
 
 // For CommonType sub-types that are followed by CommonType::getVlen()
 // number of entries of some type in the binary format.
 #define BTF_DEFINE_TAIL_ARR(Type, Accessor)                                    \
-  ArrayRef<Type> Accessor() const {                                            \
-    return ArrayRef<Type>(getTrailingObjects<Type>(), getVlen());              \
-  }
+  ArrayRef<Type> Accessor() const { return getTrailingObjects(getVlen()); }
 
 struct ArrayType final : CommonType,
                          private TrailingObjects<ArrayType, BTFArray> {
diff --git a/llvm/include/llvm/IR/DataLayout.h b/llvm/include/llvm/IR/DataLayout.h
index 2ad080e6d0cd2..d83fe1299237b 100644
--- a/llvm/include/llvm/IR/DataLayout.h
+++ b/llvm/include/llvm/IR/DataLayout.h
@@ -564,7 +564,9 @@ inline LLVMTargetDataRef wrap(const DataLayout *P) {
 
 /// Used to lazily calculate structure layout information for a target machine,
 /// based on the DataLayout structure.
-class StructLayout final : public TrailingObjects<StructLayout, TypeSize> {
+class StructLayout final : private TrailingObjects<StructLayout, TypeSize> {
+  friend TrailingObjects;
+
   TypeSize StructSize;
   Align StructAlignment;
   unsigned IsPadded : 1;
@@ -586,11 +588,11 @@ class StructLayout final : public TrailingObjects<StructLayout, TypeSize> {
   unsigned getElementContainingOffset(uint64_t FixedOffset) const;
 
   MutableArrayRef<TypeSize> getMemberOffsets() {
-    return llvm::MutableArrayRef(getTrailingObjects<TypeSize>(), NumElements);
+    return getTrailingObjects(NumElements);
   }
 
   ArrayRef<TypeSize> getMemberOffsets() const {
-    return llvm::ArrayRef(getTrailingObjects<TypeSize>(), NumElements);
+    return getTrailingObjects(NumElements);
   }
 
   TypeSize getElementOffset(unsigned Idx) const {
@@ -606,10 +608,6 @@ class StructLayout final : public TrailingObjects<StructLayout, TypeSize> {
   friend class DataLayout; // Only DataLayout can create this class
 
   StructLayout(StructType *ST, const DataLayout &DL);
-
-  size_t numTrailingObjects(OverloadToken<TypeSize>) const {
-    return NumElements;
-  }
 };
 
 // The implementation of this method is provided inline as it is particularly
diff --git a/llvm/include/llvm/Support/TrailingObjects.h b/llvm/include/llvm/Support/TrailingObjects.h
index 07cf08df45a6a..c6d274cd4301d 100644
--- a/llvm/include/llvm/Support/TrailingObjects.h
+++ b/llvm/include/llvm/Support/TrailingObjects.h
@@ -46,6 +46,7 @@
 #ifndef LLVM_SUPPORT_TRAILINGOBJECTS_H
 #define LLVM_SUPPORT_TRAILINGOBJECTS_H
 
+#include "llvm/ADT/ArrayRef.h"
 #include "llvm/Support/Alignment.h"
 #include "llvm/Support/Compiler.h"
 #include "llvm/Support/MathExtras.h"
@@ -301,6 +302,38 @@ class TrailingObjects : private trailing_objects_internal::TrailingObjectsImpl<
         static_cast<BaseTy *>(this), TrailingObjectsBase::OverloadToken<T>());
   }
 
+  // getTrailingObjects() specialization for a single trailing type.
+  using FirstTrailingType =
+      typename std::tuple_element_t<0, std::tuple<TrailingTys...>>;
+
+  const FirstTrailingType *getTrailingObjects() const {
+    static_assert(sizeof...(TrailingTys) == 1,
+                  "Can use non-templated getTrailingObjects() only when there "
+                  "is a single trailing type");
+    return getTrailingObjects<FirstTrailingType>();
+  }
+
+  FirstTrailingType *getTrailingObjects() {
+    static_assert(sizeof...(TrailingTys) == 1,
+                  "Can use non-templated getTrailingObjects() only when there "
+                  "is a single trailing type");
+    return getTrailingObjects<FirstTrailingType>();
+  }
+
+  // Functions that return the trailing objects as ArrayRefs.
+  template <typename T> MutableArrayRef<T> getTrailingObjects(size_t N) {
+    return {getTrailingObjects<T>(), N};
+  }
+  template <typename T> ArrayRef<T> getTrailingObjects(size_t N) const {
+    return {getTrailingObjects<T>(), N};
+  }
+  MutableArrayRef<FirstTrailingType> getTrailingObjects(size_t N) {
+    return {getTrailingObjects(), N};
+  }
+  ArrayRef<FirstTrailingType> getTrailingObjects(size_t N) const {
+    return {getTrailingObjects(), N};
+  }
+
   /// Returns the size of the trailing data, if an object were
   /// allocated with the given counts (The counts are in the same order
   /// as the template arguments). This does not include the size of the
diff --git a/llvm/include/llvm/TableGen/Record.h b/llvm/include/llvm/TableGen/Record.h
index 982cc255553a2..d5a138924953e 100644
--- a/llvm/include/llvm/TableGen/Record.h
+++ b/llvm/include/llvm/TableGen/Record.h
@@ -258,7 +258,7 @@ class RecordRecTy final : public RecTy,
   void Profile(FoldingSetNodeID &ID) const;
 
   ArrayRef<const Record *> getClasses() const {
-    return ArrayRef(getTrailingObjects<const Record *>(), NumClasses);
+    return getTrailingObjects(NumClasses);
   }
 
   using const_record_iterator = const Record *const *;
@@ -632,9 +632,7 @@ class BitsInit final : public TypedInit,
 
   const Init *resolveReferences(Resolver &R) const override;
 
-  ArrayRef<const Init *> getBits() const {
-    return ArrayRef(getTrailingObjects<const Init *>(), NumBits);
-  }
+  ArrayRef<const Init *> getBits() const { return getTrailingObjects(NumBits); }
 
   const Init *getBit(unsigned Bit) const override { return getBits()[Bit]; }
 };
@@ -1026,10 +1024,6 @@ class CondOpInit final : public TypedInit,
   CondOpInit(ArrayRef<const Init *> Conds, ArrayRef<const Init *> Values,
              const RecTy *Type);
 
-  size_t numTrailingObjects(OverloadToken<Init *>) const {
-    return 2*NumConds;
-  }
-
 public:
   CondOpInit(const CondOpInit &) = delete;
   CondOpInit &operator=(const CondOpInit &) = delete;
@@ -1053,11 +1047,11 @@ class CondOpInit final : public TypedInit,
   const Init *getVal(unsigned Num) const { return getVals()[Num]; }
 
   ArrayRef<const Init *> getConds() const {
-    return ArrayRef(getTrailingObjects<const Init *>(), NumConds);
+    return getTrailingObjects(NumConds);
   }
 
   ArrayRef<const Init *> getVals() const {
-    return ArrayRef(getTrailingObjects<const Init *>() + NumConds, NumConds);
+    return ArrayRef(getTrailingObjects() + NumConds, NumConds);
   }
 
   const Init *Fold(const Record *CurRec) const;
@@ -1375,7 +1369,7 @@ class VarDefInit final
   bool           args_empty() const { return NumArgs == 0; }
 
   ArrayRef<const ArgumentInit *> args() const {
-    return ArrayRef(getTrailingObjects<const ArgumentInit *>(), NumArgs);
+    return getTrailingObjects(NumArgs);
   }
 
   const Init *getBit(unsigned Bit) const override {
@@ -1488,11 +1482,11 @@ class DagInit final
   }
 
   ArrayRef<const Init *> getArgs() const {
-    return ArrayRef(getTrailingObjects<const Init *>(), NumArgs);
+    return getTrailingObjects<const Init *>(NumArgs);
   }
 
   ArrayRef<const StringInit *> getArgNames() const {
-    return ArrayRef(getTrailingObjects<const StringInit *>(), NumArgs);
+    return getTrailingObjects<const StringInit *>(NumArgs);
   }
 
   const Init *resolveReferences(Resolver &R) const override;
diff --git a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
index 4074ed65885c7..103502ae3e66f 100644
--- a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
+++ b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
@@ -560,7 +560,7 @@ class BitcodeConstant final : public Value,
   static bool classof(const Value *V) { return V->getValueID() == SubclassID; }
 
   ArrayRef<unsigned> getOperandIDs() const {
-    return ArrayRef(getTrailingObjects<unsigned>(), NumOperands);
+    return ArrayRef(getTrailingObjects(), NumOperands);
   }
 
   std::optional<ConstantRange> getInRange() const {
diff --git a/llvm/lib/IR/AttributeImpl.h b/llvm/lib/IR/AttributeImpl.h
index 59cc489ade40d..7dcc4fd2213e5 100644
--- a/llvm/lib/IR/AttributeImpl.h
+++ b/llvm/lib/IR/AttributeImpl.h
@@ -195,15 +195,12 @@ class StringAttributeImpl final
 
   unsigned KindSize;
   unsigned ValSize;
-  size_t numTrailingObjects(OverloadToken<char>) const {
-    return KindSize + 1 + ValSize + 1;
-  }
 
 public:
   StringAttributeImpl(StringRef Kind, StringRef Val = StringRef())
       : AttributeImpl(StringAttrEntry), KindSize(Kind.size()),
         ValSize(Val.size()) {
-    char *TrailingString = getTrailingObjects<char>();
+    char *TrailingString = getTrailingObjects();
     // Some users rely on zero-termination.
     llvm::copy(Kind, TrailingString);
     TrailingString[KindSize] = '\0';
@@ -212,10 +209,10 @@ class StringAttributeImpl final
   }
 
   StringRef getStringKind() const {
-    return StringRef(getTrailingObjects<char>(), KindSize);
+    return StringRef(getTrailingObjects(), KindSize);
   }
   StringRef getStringValue() const {
-    return StringRef(getTrailingObjects<char>() + KindSize + 1, ValSize);
+    return StringRef(getTrailingObjects() + KindSize + 1, ValSize);
   }
 
   static size_t totalSizeToAlloc(StringRef Kind, StringRef Val) {
@@ -250,25 +247,23 @@ class ConstantRangeListAttributeImpl final
   friend TrailingObjects;
 
   unsigned Size;
-  size_t numTrailingObjects(OverloadToken<ConstantRange>) const { return Size; }
 
 public:
   ConstantRangeListAttributeImpl(Attribute::AttrKind Kind,
                                  ArrayRef<ConstantRange> Val)
       : EnumAttributeImpl(ConstantRangeListAttrEntry, Kind), Size(Val.size()) {
     assert(Size > 0);
-    ConstantRange *TrailingCR = getTrailingObjects<ConstantRange>();
+    ConstantRange *TrailingCR = getTrailingObjects();
     std::uninitialized_copy(Val.begin(), Val.end(), TrailingCR);
   }
 
   ~ConstantRangeListAttributeImpl() {
-    ConstantRange *TrailingCR = getTrailingObjects<ConstantRange>();
-    for (unsigned I = 0; I != Size; ++I)
-      TrailingCR[I].~ConstantRange();
+    for (ConstantRange &CR : getTrailingObjects(Size))
+      CR.~ConstantRange();
   }
 
   ArrayRef<ConstantRange> getConstantRangeListValue() const {
-    return ArrayRef(getTrailingObjects<ConstantRange>(), Size);
+    return getTrailingObjects(Size);
   }
 
   static size_t totalSizeToAlloc(ArrayRef<ConstantRange> Val) {
@@ -353,7 +348,7 @@ class AttributeSetNode final
 
   using iterator = const Attribute *;
 
-  iterator begin() const { return getTrailingObjects<Attribute>(); }
+  iterator begin() const { return getTrailingObjects(); }
   iterator end() const { return begin() + NumAttrs; }
 
   void Profile(FoldingSetNodeID &ID) const {
@@ -383,9 +378,6 @@ class AttributeListImpl final
   /// Union of enum attributes available at any index.
   AttributeBitSet AvailableSomewhereAttrs;
 
-  // Helper fn for TrailingObjects class.
-  size_t numTrailingObjects(OverloadToken<AttributeSet>) { return NumAttrSets; }
-
 public:
   AttributeListImpl(ArrayRef<AttributeSet> Sets);
 
@@ -407,7 +399,7 @@ class AttributeListImpl final
 
   using iterator = const AttributeSet *;
 
-  iterator begin() const { return getTrailingObjects<AttributeSet>(); }
+  iterator begin() const { return getTrailingObjects(); }
   iterator end() const { return begin() + NumAttrSets; }
 
   void Profile(FoldingSetNodeID &ID) const;
diff --git a/llvm/lib/IR/Attributes.cpp b/llvm/lib/IR/Attributes.cpp
index 33ac8bfaf4e7c..5b0ceb34381a9 100644
--- a/llvm/lib/IR/Attributes.cpp
+++ b/llvm/lib/IR/Attributes.cpp
@@ -1237,7 +1237,7 @@ LLVM_DUMP_METHOD void AttributeSet::dump() const {
 AttributeSetNode::AttributeSetNode(ArrayRef<Attribute> Attrs)
     : NumAttrs(Attrs.size()) {
   // There's memory after the node where we can store the entries in.
-  llvm::copy(Attrs, getTrailingObjects<Attribute>());
+  llvm::copy(Attrs, getTrailingObjects());
 
   for (const auto &I : *this) {
     if (I.isStringAttribute())
@@ -1423,7 +1423,7 @@ AttributeListImpl::AttributeListImpl(ArrayRef<AttributeSet> Sets)
   assert(!Sets.empty() && "pointless AttributeListImpl");
 
   // There's memory after the node where we can store the entries in.
-  llvm::copy(Sets, getTrailingObjects<AttributeSet>());
+  llvm::copy(Sets, getTrailingObjects());
 
   // Initialize AvailableFunctionAttrs and AvailableSomewhereAttrs
   // summary bitsets.
diff --git a/llvm/lib/Support/TrieRawHashMap.cpp b/llvm/lib/Support/TrieRawHashMap.cpp
index 11d79a62d011d..bb779fe87ae62 100644
--- a/llvm/lib/Support/TrieRawHashMap.cpp
+++ b/llvm/lib/Support/TrieRawHashMap.cpp
@@ -62,7 +62,7 @@ class TrieSubtrie final
 public:
   using Slot = LazyAtomicPointer<TrieNode>;
 
-  Slot &get(size_t I) { return getTrailingObjects<Slot>()[I]; }
+  Slot &get(size_t I) { return getTrailingObjects()[I]; }
   TrieNode *load(size_t I) { return get(I).load(); }
 
   unsigned size() const { return Size; }
@@ -190,7 +190,7 @@ class ThreadSafeTrieRawHashMapBase::ImplType final
   }
 
   // Get the root which is the trailing object.
-  TrieSubtrie *getRoot() { return getTrailingObjects<TrieSubtrie>(); }
+  TrieSubtrie *getRoot() { return getTrailingObjects(); }
 
   static void *operator new(size_t Size) { return ::operator new(Size); }
   void operator delete(void *Ptr) { ::operator delete(Ptr); }
diff --git a/llvm/lib/Transforms/IPO/LowerTypeTests.cpp b/llvm/lib/Transforms/IPO/LowerTypeTests.cpp
index c93568943e833..aa6672a57f9a5 100644
--- a/llvm/lib/Transforms/IPO/LowerTypeTests.cpp
+++ b/llvm/lib/Transforms/IPO/LowerTypeTests.cpp
@@ -285,8 +285,6 @@ class GlobalTypeMember final : TrailingObjects<GlobalTypeMember, MDNode *> {
   // module and its jumptable entry needs to be exported to thinlto backends.
   bool IsExported;
 
-  size_t numTrailingObjects(OverloadToken<MDNode *>) const { return NTypes; }
-
 public:
   static GlobalTypeMember *create(BumpPtrAllocator &Alloc, GlobalObject *GO,
                                   bool IsJumpTableCanonical, bool IsExported,
@@ -298,7 +296,7 @@ class GlobalTypeMember final : TrailingObjects<GlobalTypeMember, MDNode *> {
     GTM->IsJumpTableCanonical = IsJumpTableCanonical;
     GTM->IsExported = IsExported;
     std::uninitialized_copy(Types.begin(), Types.end(),
-                            GTM->getTrailingObjects<MDNode *>());
+                            GTM->getTrailingObjects());
     return GTM;
   }
 
@@ -314,9 +312,7 @@ class GlobalTypeMember final : TrailingObjects<GlobalTypeMember, MDNode *> {
     return IsExported;
   }
 
-  ArrayRef<MDNode *> types() const {
-    return ArrayRef(getTrailingObjects<MDNode *>(), NTypes);
-  }
+  ArrayRef<MDNode *> types() const { return getTrailingObjects(NTypes); }
 };
 
 struct ICallBranchFunnel final
@@ -331,13 +327,13 @@ struct ICallBranchFunnel final
     Call->UniqueId = UniqueId;
     Call->NTargets = Targets.size();
     std::uninitialized_copy(Targets.begin(), Targets.end(),
-                            Call->getTrailingObjects<GlobalTypeMember *>());
+                            Call->getTrailingObjects());
     return Call;
   }
 
   CallInst *CI;
   ArrayRef<GlobalTypeMember *> targets() const {
-    return ArrayRef(getTrailingObjects<GlobalTypeMember *>(), NTargets);
+    return getTrailingObjects(NTargets);
   }
 
   unsigned UniqueId;
diff --git a/llvm/unittests/Support/TrailingObjectsTest.cpp b/llvm/unittests/Support/TrailingObjectsTest.cpp
index e2656b2229ca6..6e4447cd9d670 100644
--- a/llvm/unittests/Support/TrailingObjectsTest.cpp
+++ b/llvm/unittests/Support/TrailingObjectsTest.cpp
@@ -21,11 +21,9 @@ class Class1 final : protected TrailingObjects<Class1, short> {
   unsigned NumShorts;
 
 protected:
-  size_t numTrailingObjects(OverloadToken<short>) const { return NumShorts; }
-
   Class1(int *ShortArray, unsigned NumShorts) : NumShorts(NumShorts) {
     std::uninitialized_copy(ShortArray, ShortArray + NumShorts,
-                            getTrailingObjects<short>());
+                            getTrailingObjects());
   }
 
 public:
@@ -35,7 +33,7 @@ class Class1 final : protected TrailingObjects<Class1, short> {
   }
   void operator delete(void *p) { ::operator delete(p); }
 
-  short get(unsigned Num) const { return getTrailingObjects<short>()[Num]; }
+  short get(unsigned Num) const { return getTrailingObjects()[Num]; }
 
   unsigned numShorts() const { return NumShorts; }
 
diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
index 95d944170732e..68ab1527b480a 100644
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -679,8 +679,7 @@ class alignas(8) Operation final
     if (numRegions == 0)
       return MutableArrayRef<Region>();
 
-    auto *regions = getTrailingObjects<Region>();
-    return {regions, numRegions};
+    return getTrailingObjects<Region>(numRegions);
   }
 
   /// Returns the region held by this operation at position 'index'.
@@ -694,7 +693,7 @@ class alignas(8) Operation final
   //===--------------------------------------------------------------------===//
 
   MutableArrayRef<BlockOperand> getBlockOperands() {
-    return {getTrailingObjects<BlockOperand>(), numSuccs};
+    return getTrailingObjects<BlockOperand>(numSuccs);
   }
 
   // Successor iteration.
diff --git a/mlir/include/mlir/Tools/PDLL/AST/Nodes.h b/mlir/include/mlir/Tools/PDLL/AST/Nodes.h
index f174ac2f476f6..9ad94839890b7 100644
--- a/mlir/include/mlir/Tools/PDLL/AST/Nodes.h
+++ b/mlir/include/mlir/Tools/PDLL/AST/Nodes.h
@@ -183,10 +183,10 @@ class CompoundStmt final : public Node::NodeBase<CompoundStmt, Stmt>,
 
   /// Return the children of this compound statement.
   MutableArrayRef<Stmt *> getChildren() {
-    return {getTrailingObjects<Stmt *>(), numChildren};
+    return getTrailingObjects(numChildren);
   }
   ArrayRef<Stmt *> getChildren() const {
-    return const_cast<CompoundStmt *>(this)->getChildren();
+    return getTrailingObjects(numChildren);
   }
   ArrayRef<Stmt *>::iterator begin() const { return getChildren().begin(); }
   ArrayRef<Stmt *>::iterator end() const { return getChildren().end(); }
@@ -275,10 +275,10 @@ class ReplaceStmt final : public Node::NodeBase<ReplaceStmt, OpRewriteStmt>,
 
   /// Return the replacement values of this statement.
   MutableArrayRef<Expr *> getReplExprs() {
-    return {getTrailingObjects<Expr *>(), numReplExprs};
+    return getTrailingObjects(numReplExprs);
   }
   ArrayRef<Expr *> getReplExprs() const {
-    return const_cast<ReplaceStmt *>(this)->getReplExprs();
+    return getTrailingObjects(numReplExprs);
   }
 
 private:
@@ -400,12 +400,8 @@ class CallExpr final : public Node::NodeBase<CallExpr, Expr>,
   Expr *getCallableExpr() const { return callable; }
 
   /// Return the arguments of this call.
-  MutableArrayRef<Expr *> getArguments() {
-    return {getTrailingObjects<Expr *>(), numArgs};
-  }
-  ArrayRef<Expr *> getArguments() const {
-    return const_cast<CallExpr *>(this)->getArguments();
-  }
+  MutableArrayRef<Expr *> getArguments() { return getTrailingObjects(numArgs); }
+  ArrayRef<Expr *> getArguments() const { return getTrailingObjects(numArgs); }
 
   /// Returns whether the result of this call is to be negated.
   bool getIsNegated() const { return isNegated; }
@@ -534,10 +530,10 @@ class OperationExpr final
 
   /// Return the operands of this operation.
   MutableArrayRef<Expr *> getOperands() {
-    return {getTrailingObjects<Expr *>(), numOperands};
+    return getTrailingObjects<Expr *>(numOperands);
   }
   ArrayRef<Expr *> getOperands() const {
-    return const_cast<OperationExpr *>(this)->getOperands();
+    return getTrailingObjects<Expr *>(numOperands);
   }
 
   /// Return the result types of this operation.
@@ -550,10 +546,10 @@ class OperationExpr final
 
   /// Return the attributes of this operation.
   MutableArrayRef<NamedAttributeDecl *> getAttributes() {
-    return {getTrailingObjects<NamedAttributeDecl *>(), numAttributes};
+    return getTrailingObjects<NamedAttributeDecl *>(numAttributes);
   }
-  MutableArrayRef<NamedAttributeDecl *> getAttributes() const {
-    return const_cast<OperationExpr *>(this)->getAttributes();
+  ArrayRef<NamedAttributeDecl *> getAtt...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented May 7, 2025

@llvm/pr-subscribers-llvm-transforms

Author: Rahul Joshi (jurahul)

Changes

Add a specialization of getTrailingObjects() for a single trailing type. This is a common case and with the specialization you don't need to specify the single trailing type redundantly. Also add an overload for getTrailingObjects which takes size and returns an ArryaRef/MutableArrayRef as that's a common use case as well.

Fix some uses of TrailingObjects class to use private inheritance and remove unneeded numTrailingObjects function when there is a single trailing type.


Patch is 27.70 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/138554.diff

15 Files Affected:

  • (modified) llvm/include/llvm/DebugInfo/BTF/BTF.h (+2-4)
  • (modified) llvm/include/llvm/IR/DataLayout.h (+5-7)
  • (modified) llvm/include/llvm/Support/TrailingObjects.h (+33)
  • (modified) llvm/include/llvm/TableGen/Record.h (+7-13)
  • (modified) llvm/lib/Bitcode/Reader/BitcodeReader.cpp (+1-1)
  • (modified) llvm/lib/IR/AttributeImpl.h (+9-17)
  • (modified) llvm/lib/IR/Attributes.cpp (+2-2)
  • (modified) llvm/lib/Support/TrieRawHashMap.cpp (+2-2)
  • (modified) llvm/lib/Transforms/IPO/LowerTypeTests.cpp (+4-8)
  • (modified) llvm/unittests/Support/TrailingObjectsTest.cpp (+2-4)
  • (modified) mlir/include/mlir/IR/Operation.h (+2-3)
  • (modified) mlir/include/mlir/Tools/PDLL/AST/Nodes.h (+24-28)
  • (modified) mlir/lib/IR/AffineMapDetail.h (+5-3)
  • (modified) mlir/lib/IR/Location.cpp (+6-5)
  • (modified) mlir/lib/IR/TypeDetail.h (+4-5)
diff --git a/llvm/include/llvm/DebugInfo/BTF/BTF.h b/llvm/include/llvm/DebugInfo/BTF/BTF.h
index d88af2ff30bdf..bd666e4202985 100644
--- a/llvm/include/llvm/DebugInfo/BTF/BTF.h
+++ b/llvm/include/llvm/DebugInfo/BTF/BTF.h
@@ -304,14 +304,12 @@ enum PatchableRelocKind : uint32_t {
 // For CommonType sub-types that are followed by a single entry of
 // some type in the binary format.
 #define BTF_DEFINE_TAIL(Type, Accessor)                                        \
-  const Type &Accessor() const { return *getTrailingObjects<Type>(); }
+  const Type &Accessor() const { return *getTrailingObjects(); }
 
 // For CommonType sub-types that are followed by CommonType::getVlen()
 // number of entries of some type in the binary format.
 #define BTF_DEFINE_TAIL_ARR(Type, Accessor)                                    \
-  ArrayRef<Type> Accessor() const {                                            \
-    return ArrayRef<Type>(getTrailingObjects<Type>(), getVlen());              \
-  }
+  ArrayRef<Type> Accessor() const { return getTrailingObjects(getVlen()); }
 
 struct ArrayType final : CommonType,
                          private TrailingObjects<ArrayType, BTFArray> {
diff --git a/llvm/include/llvm/IR/DataLayout.h b/llvm/include/llvm/IR/DataLayout.h
index 2ad080e6d0cd2..d83fe1299237b 100644
--- a/llvm/include/llvm/IR/DataLayout.h
+++ b/llvm/include/llvm/IR/DataLayout.h
@@ -564,7 +564,9 @@ inline LLVMTargetDataRef wrap(const DataLayout *P) {
 
 /// Used to lazily calculate structure layout information for a target machine,
 /// based on the DataLayout structure.
-class StructLayout final : public TrailingObjects<StructLayout, TypeSize> {
+class StructLayout final : private TrailingObjects<StructLayout, TypeSize> {
+  friend TrailingObjects;
+
   TypeSize StructSize;
   Align StructAlignment;
   unsigned IsPadded : 1;
@@ -586,11 +588,11 @@ class StructLayout final : public TrailingObjects<StructLayout, TypeSize> {
   unsigned getElementContainingOffset(uint64_t FixedOffset) const;
 
   MutableArrayRef<TypeSize> getMemberOffsets() {
-    return llvm::MutableArrayRef(getTrailingObjects<TypeSize>(), NumElements);
+    return getTrailingObjects(NumElements);
   }
 
   ArrayRef<TypeSize> getMemberOffsets() const {
-    return llvm::ArrayRef(getTrailingObjects<TypeSize>(), NumElements);
+    return getTrailingObjects(NumElements);
   }
 
   TypeSize getElementOffset(unsigned Idx) const {
@@ -606,10 +608,6 @@ class StructLayout final : public TrailingObjects<StructLayout, TypeSize> {
   friend class DataLayout; // Only DataLayout can create this class
 
   StructLayout(StructType *ST, const DataLayout &DL);
-
-  size_t numTrailingObjects(OverloadToken<TypeSize>) const {
-    return NumElements;
-  }
 };
 
 // The implementation of this method is provided inline as it is particularly
diff --git a/llvm/include/llvm/Support/TrailingObjects.h b/llvm/include/llvm/Support/TrailingObjects.h
index 07cf08df45a6a..c6d274cd4301d 100644
--- a/llvm/include/llvm/Support/TrailingObjects.h
+++ b/llvm/include/llvm/Support/TrailingObjects.h
@@ -46,6 +46,7 @@
 #ifndef LLVM_SUPPORT_TRAILINGOBJECTS_H
 #define LLVM_SUPPORT_TRAILINGOBJECTS_H
 
+#include "llvm/ADT/ArrayRef.h"
 #include "llvm/Support/Alignment.h"
 #include "llvm/Support/Compiler.h"
 #include "llvm/Support/MathExtras.h"
@@ -301,6 +302,38 @@ class TrailingObjects : private trailing_objects_internal::TrailingObjectsImpl<
         static_cast<BaseTy *>(this), TrailingObjectsBase::OverloadToken<T>());
   }
 
+  // getTrailingObjects() specialization for a single trailing type.
+  using FirstTrailingType =
+      typename std::tuple_element_t<0, std::tuple<TrailingTys...>>;
+
+  const FirstTrailingType *getTrailingObjects() const {
+    static_assert(sizeof...(TrailingTys) == 1,
+                  "Can use non-templated getTrailingObjects() only when there "
+                  "is a single trailing type");
+    return getTrailingObjects<FirstTrailingType>();
+  }
+
+  FirstTrailingType *getTrailingObjects() {
+    static_assert(sizeof...(TrailingTys) == 1,
+                  "Can use non-templated getTrailingObjects() only when there "
+                  "is a single trailing type");
+    return getTrailingObjects<FirstTrailingType>();
+  }
+
+  // Functions that return the trailing objects as ArrayRefs.
+  template <typename T> MutableArrayRef<T> getTrailingObjects(size_t N) {
+    return {getTrailingObjects<T>(), N};
+  }
+  template <typename T> ArrayRef<T> getTrailingObjects(size_t N) const {
+    return {getTrailingObjects<T>(), N};
+  }
+  MutableArrayRef<FirstTrailingType> getTrailingObjects(size_t N) {
+    return {getTrailingObjects(), N};
+  }
+  ArrayRef<FirstTrailingType> getTrailingObjects(size_t N) const {
+    return {getTrailingObjects(), N};
+  }
+
   /// Returns the size of the trailing data, if an object were
   /// allocated with the given counts (The counts are in the same order
   /// as the template arguments). This does not include the size of the
diff --git a/llvm/include/llvm/TableGen/Record.h b/llvm/include/llvm/TableGen/Record.h
index 982cc255553a2..d5a138924953e 100644
--- a/llvm/include/llvm/TableGen/Record.h
+++ b/llvm/include/llvm/TableGen/Record.h
@@ -258,7 +258,7 @@ class RecordRecTy final : public RecTy,
   void Profile(FoldingSetNodeID &ID) const;
 
   ArrayRef<const Record *> getClasses() const {
-    return ArrayRef(getTrailingObjects<const Record *>(), NumClasses);
+    return getTrailingObjects(NumClasses);
   }
 
   using const_record_iterator = const Record *const *;
@@ -632,9 +632,7 @@ class BitsInit final : public TypedInit,
 
   const Init *resolveReferences(Resolver &R) const override;
 
-  ArrayRef<const Init *> getBits() const {
-    return ArrayRef(getTrailingObjects<const Init *>(), NumBits);
-  }
+  ArrayRef<const Init *> getBits() const { return getTrailingObjects(NumBits); }
 
   const Init *getBit(unsigned Bit) const override { return getBits()[Bit]; }
 };
@@ -1026,10 +1024,6 @@ class CondOpInit final : public TypedInit,
   CondOpInit(ArrayRef<const Init *> Conds, ArrayRef<const Init *> Values,
              const RecTy *Type);
 
-  size_t numTrailingObjects(OverloadToken<Init *>) const {
-    return 2*NumConds;
-  }
-
 public:
   CondOpInit(const CondOpInit &) = delete;
   CondOpInit &operator=(const CondOpInit &) = delete;
@@ -1053,11 +1047,11 @@ class CondOpInit final : public TypedInit,
   const Init *getVal(unsigned Num) const { return getVals()[Num]; }
 
   ArrayRef<const Init *> getConds() const {
-    return ArrayRef(getTrailingObjects<const Init *>(), NumConds);
+    return getTrailingObjects(NumConds);
   }
 
   ArrayRef<const Init *> getVals() const {
-    return ArrayRef(getTrailingObjects<const Init *>() + NumConds, NumConds);
+    return ArrayRef(getTrailingObjects() + NumConds, NumConds);
   }
 
   const Init *Fold(const Record *CurRec) const;
@@ -1375,7 +1369,7 @@ class VarDefInit final
   bool           args_empty() const { return NumArgs == 0; }
 
   ArrayRef<const ArgumentInit *> args() const {
-    return ArrayRef(getTrailingObjects<const ArgumentInit *>(), NumArgs);
+    return getTrailingObjects(NumArgs);
   }
 
   const Init *getBit(unsigned Bit) const override {
@@ -1488,11 +1482,11 @@ class DagInit final
   }
 
   ArrayRef<const Init *> getArgs() const {
-    return ArrayRef(getTrailingObjects<const Init *>(), NumArgs);
+    return getTrailingObjects<const Init *>(NumArgs);
   }
 
   ArrayRef<const StringInit *> getArgNames() const {
-    return ArrayRef(getTrailingObjects<const StringInit *>(), NumArgs);
+    return getTrailingObjects<const StringInit *>(NumArgs);
   }
 
   const Init *resolveReferences(Resolver &R) const override;
diff --git a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
index 4074ed65885c7..103502ae3e66f 100644
--- a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
+++ b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
@@ -560,7 +560,7 @@ class BitcodeConstant final : public Value,
   static bool classof(const Value *V) { return V->getValueID() == SubclassID; }
 
   ArrayRef<unsigned> getOperandIDs() const {
-    return ArrayRef(getTrailingObjects<unsigned>(), NumOperands);
+    return ArrayRef(getTrailingObjects(), NumOperands);
   }
 
   std::optional<ConstantRange> getInRange() const {
diff --git a/llvm/lib/IR/AttributeImpl.h b/llvm/lib/IR/AttributeImpl.h
index 59cc489ade40d..7dcc4fd2213e5 100644
--- a/llvm/lib/IR/AttributeImpl.h
+++ b/llvm/lib/IR/AttributeImpl.h
@@ -195,15 +195,12 @@ class StringAttributeImpl final
 
   unsigned KindSize;
   unsigned ValSize;
-  size_t numTrailingObjects(OverloadToken<char>) const {
-    return KindSize + 1 + ValSize + 1;
-  }
 
 public:
   StringAttributeImpl(StringRef Kind, StringRef Val = StringRef())
       : AttributeImpl(StringAttrEntry), KindSize(Kind.size()),
         ValSize(Val.size()) {
-    char *TrailingString = getTrailingObjects<char>();
+    char *TrailingString = getTrailingObjects();
     // Some users rely on zero-termination.
     llvm::copy(Kind, TrailingString);
     TrailingString[KindSize] = '\0';
@@ -212,10 +209,10 @@ class StringAttributeImpl final
   }
 
   StringRef getStringKind() const {
-    return StringRef(getTrailingObjects<char>(), KindSize);
+    return StringRef(getTrailingObjects(), KindSize);
   }
   StringRef getStringValue() const {
-    return StringRef(getTrailingObjects<char>() + KindSize + 1, ValSize);
+    return StringRef(getTrailingObjects() + KindSize + 1, ValSize);
   }
 
   static size_t totalSizeToAlloc(StringRef Kind, StringRef Val) {
@@ -250,25 +247,23 @@ class ConstantRangeListAttributeImpl final
   friend TrailingObjects;
 
   unsigned Size;
-  size_t numTrailingObjects(OverloadToken<ConstantRange>) const { return Size; }
 
 public:
   ConstantRangeListAttributeImpl(Attribute::AttrKind Kind,
                                  ArrayRef<ConstantRange> Val)
       : EnumAttributeImpl(ConstantRangeListAttrEntry, Kind), Size(Val.size()) {
     assert(Size > 0);
-    ConstantRange *TrailingCR = getTrailingObjects<ConstantRange>();
+    ConstantRange *TrailingCR = getTrailingObjects();
     std::uninitialized_copy(Val.begin(), Val.end(), TrailingCR);
   }
 
   ~ConstantRangeListAttributeImpl() {
-    ConstantRange *TrailingCR = getTrailingObjects<ConstantRange>();
-    for (unsigned I = 0; I != Size; ++I)
-      TrailingCR[I].~ConstantRange();
+    for (ConstantRange &CR : getTrailingObjects(Size))
+      CR.~ConstantRange();
   }
 
   ArrayRef<ConstantRange> getConstantRangeListValue() const {
-    return ArrayRef(getTrailingObjects<ConstantRange>(), Size);
+    return getTrailingObjects(Size);
   }
 
   static size_t totalSizeToAlloc(ArrayRef<ConstantRange> Val) {
@@ -353,7 +348,7 @@ class AttributeSetNode final
 
   using iterator = const Attribute *;
 
-  iterator begin() const { return getTrailingObjects<Attribute>(); }
+  iterator begin() const { return getTrailingObjects(); }
   iterator end() const { return begin() + NumAttrs; }
 
   void Profile(FoldingSetNodeID &ID) const {
@@ -383,9 +378,6 @@ class AttributeListImpl final
   /// Union of enum attributes available at any index.
   AttributeBitSet AvailableSomewhereAttrs;
 
-  // Helper fn for TrailingObjects class.
-  size_t numTrailingObjects(OverloadToken<AttributeSet>) { return NumAttrSets; }
-
 public:
   AttributeListImpl(ArrayRef<AttributeSet> Sets);
 
@@ -407,7 +399,7 @@ class AttributeListImpl final
 
   using iterator = const AttributeSet *;
 
-  iterator begin() const { return getTrailingObjects<AttributeSet>(); }
+  iterator begin() const { return getTrailingObjects(); }
   iterator end() const { return begin() + NumAttrSets; }
 
   void Profile(FoldingSetNodeID &ID) const;
diff --git a/llvm/lib/IR/Attributes.cpp b/llvm/lib/IR/Attributes.cpp
index 33ac8bfaf4e7c..5b0ceb34381a9 100644
--- a/llvm/lib/IR/Attributes.cpp
+++ b/llvm/lib/IR/Attributes.cpp
@@ -1237,7 +1237,7 @@ LLVM_DUMP_METHOD void AttributeSet::dump() const {
 AttributeSetNode::AttributeSetNode(ArrayRef<Attribute> Attrs)
     : NumAttrs(Attrs.size()) {
   // There's memory after the node where we can store the entries in.
-  llvm::copy(Attrs, getTrailingObjects<Attribute>());
+  llvm::copy(Attrs, getTrailingObjects());
 
   for (const auto &I : *this) {
     if (I.isStringAttribute())
@@ -1423,7 +1423,7 @@ AttributeListImpl::AttributeListImpl(ArrayRef<AttributeSet> Sets)
   assert(!Sets.empty() && "pointless AttributeListImpl");
 
   // There's memory after the node where we can store the entries in.
-  llvm::copy(Sets, getTrailingObjects<AttributeSet>());
+  llvm::copy(Sets, getTrailingObjects());
 
   // Initialize AvailableFunctionAttrs and AvailableSomewhereAttrs
   // summary bitsets.
diff --git a/llvm/lib/Support/TrieRawHashMap.cpp b/llvm/lib/Support/TrieRawHashMap.cpp
index 11d79a62d011d..bb779fe87ae62 100644
--- a/llvm/lib/Support/TrieRawHashMap.cpp
+++ b/llvm/lib/Support/TrieRawHashMap.cpp
@@ -62,7 +62,7 @@ class TrieSubtrie final
 public:
   using Slot = LazyAtomicPointer<TrieNode>;
 
-  Slot &get(size_t I) { return getTrailingObjects<Slot>()[I]; }
+  Slot &get(size_t I) { return getTrailingObjects()[I]; }
   TrieNode *load(size_t I) { return get(I).load(); }
 
   unsigned size() const { return Size; }
@@ -190,7 +190,7 @@ class ThreadSafeTrieRawHashMapBase::ImplType final
   }
 
   // Get the root which is the trailing object.
-  TrieSubtrie *getRoot() { return getTrailingObjects<TrieSubtrie>(); }
+  TrieSubtrie *getRoot() { return getTrailingObjects(); }
 
   static void *operator new(size_t Size) { return ::operator new(Size); }
   void operator delete(void *Ptr) { ::operator delete(Ptr); }
diff --git a/llvm/lib/Transforms/IPO/LowerTypeTests.cpp b/llvm/lib/Transforms/IPO/LowerTypeTests.cpp
index c93568943e833..aa6672a57f9a5 100644
--- a/llvm/lib/Transforms/IPO/LowerTypeTests.cpp
+++ b/llvm/lib/Transforms/IPO/LowerTypeTests.cpp
@@ -285,8 +285,6 @@ class GlobalTypeMember final : TrailingObjects<GlobalTypeMember, MDNode *> {
   // module and its jumptable entry needs to be exported to thinlto backends.
   bool IsExported;
 
-  size_t numTrailingObjects(OverloadToken<MDNode *>) const { return NTypes; }
-
 public:
   static GlobalTypeMember *create(BumpPtrAllocator &Alloc, GlobalObject *GO,
                                   bool IsJumpTableCanonical, bool IsExported,
@@ -298,7 +296,7 @@ class GlobalTypeMember final : TrailingObjects<GlobalTypeMember, MDNode *> {
     GTM->IsJumpTableCanonical = IsJumpTableCanonical;
     GTM->IsExported = IsExported;
     std::uninitialized_copy(Types.begin(), Types.end(),
-                            GTM->getTrailingObjects<MDNode *>());
+                            GTM->getTrailingObjects());
     return GTM;
   }
 
@@ -314,9 +312,7 @@ class GlobalTypeMember final : TrailingObjects<GlobalTypeMember, MDNode *> {
     return IsExported;
   }
 
-  ArrayRef<MDNode *> types() const {
-    return ArrayRef(getTrailingObjects<MDNode *>(), NTypes);
-  }
+  ArrayRef<MDNode *> types() const { return getTrailingObjects(NTypes); }
 };
 
 struct ICallBranchFunnel final
@@ -331,13 +327,13 @@ struct ICallBranchFunnel final
     Call->UniqueId = UniqueId;
     Call->NTargets = Targets.size();
     std::uninitialized_copy(Targets.begin(), Targets.end(),
-                            Call->getTrailingObjects<GlobalTypeMember *>());
+                            Call->getTrailingObjects());
     return Call;
   }
 
   CallInst *CI;
   ArrayRef<GlobalTypeMember *> targets() const {
-    return ArrayRef(getTrailingObjects<GlobalTypeMember *>(), NTargets);
+    return getTrailingObjects(NTargets);
   }
 
   unsigned UniqueId;
diff --git a/llvm/unittests/Support/TrailingObjectsTest.cpp b/llvm/unittests/Support/TrailingObjectsTest.cpp
index e2656b2229ca6..6e4447cd9d670 100644
--- a/llvm/unittests/Support/TrailingObjectsTest.cpp
+++ b/llvm/unittests/Support/TrailingObjectsTest.cpp
@@ -21,11 +21,9 @@ class Class1 final : protected TrailingObjects<Class1, short> {
   unsigned NumShorts;
 
 protected:
-  size_t numTrailingObjects(OverloadToken<short>) const { return NumShorts; }
-
   Class1(int *ShortArray, unsigned NumShorts) : NumShorts(NumShorts) {
     std::uninitialized_copy(ShortArray, ShortArray + NumShorts,
-                            getTrailingObjects<short>());
+                            getTrailingObjects());
   }
 
 public:
@@ -35,7 +33,7 @@ class Class1 final : protected TrailingObjects<Class1, short> {
   }
   void operator delete(void *p) { ::operator delete(p); }
 
-  short get(unsigned Num) const { return getTrailingObjects<short>()[Num]; }
+  short get(unsigned Num) const { return getTrailingObjects()[Num]; }
 
   unsigned numShorts() const { return NumShorts; }
 
diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
index 95d944170732e..68ab1527b480a 100644
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -679,8 +679,7 @@ class alignas(8) Operation final
     if (numRegions == 0)
       return MutableArrayRef<Region>();
 
-    auto *regions = getTrailingObjects<Region>();
-    return {regions, numRegions};
+    return getTrailingObjects<Region>(numRegions);
   }
 
   /// Returns the region held by this operation at position 'index'.
@@ -694,7 +693,7 @@ class alignas(8) Operation final
   //===--------------------------------------------------------------------===//
 
   MutableArrayRef<BlockOperand> getBlockOperands() {
-    return {getTrailingObjects<BlockOperand>(), numSuccs};
+    return getTrailingObjects<BlockOperand>(numSuccs);
   }
 
   // Successor iteration.
diff --git a/mlir/include/mlir/Tools/PDLL/AST/Nodes.h b/mlir/include/mlir/Tools/PDLL/AST/Nodes.h
index f174ac2f476f6..9ad94839890b7 100644
--- a/mlir/include/mlir/Tools/PDLL/AST/Nodes.h
+++ b/mlir/include/mlir/Tools/PDLL/AST/Nodes.h
@@ -183,10 +183,10 @@ class CompoundStmt final : public Node::NodeBase<CompoundStmt, Stmt>,
 
   /// Return the children of this compound statement.
   MutableArrayRef<Stmt *> getChildren() {
-    return {getTrailingObjects<Stmt *>(), numChildren};
+    return getTrailingObjects(numChildren);
   }
   ArrayRef<Stmt *> getChildren() const {
-    return const_cast<CompoundStmt *>(this)->getChildren();
+    return getTrailingObjects(numChildren);
   }
   ArrayRef<Stmt *>::iterator begin() const { return getChildren().begin(); }
   ArrayRef<Stmt *>::iterator end() const { return getChildren().end(); }
@@ -275,10 +275,10 @@ class ReplaceStmt final : public Node::NodeBase<ReplaceStmt, OpRewriteStmt>,
 
   /// Return the replacement values of this statement.
   MutableArrayRef<Expr *> getReplExprs() {
-    return {getTrailingObjects<Expr *>(), numReplExprs};
+    return getTrailingObjects(numReplExprs);
   }
   ArrayRef<Expr *> getReplExprs() const {
-    return const_cast<ReplaceStmt *>(this)->getReplExprs();
+    return getTrailingObjects(numReplExprs);
   }
 
 private:
@@ -400,12 +400,8 @@ class CallExpr final : public Node::NodeBase<CallExpr, Expr>,
   Expr *getCallableExpr() const { return callable; }
 
   /// Return the arguments of this call.
-  MutableArrayRef<Expr *> getArguments() {
-    return {getTrailingObjects<Expr *>(), numArgs};
-  }
-  ArrayRef<Expr *> getArguments() const {
-    return const_cast<CallExpr *>(this)->getArguments();
-  }
+  MutableArrayRef<Expr *> getArguments() { return getTrailingObjects(numArgs); }
+  ArrayRef<Expr *> getArguments() const { return getTrailingObjects(numArgs); }
 
   /// Returns whether the result of this call is to be negated.
   bool getIsNegated() const { return isNegated; }
@@ -534,10 +530,10 @@ class OperationExpr final
 
   /// Return the operands of this operation.
   MutableArrayRef<Expr *> getOperands() {
-    return {getTrailingObjects<Expr *>(), numOperands};
+    return getTrailingObjects<Expr *>(numOperands);
   }
   ArrayRef<Expr *> getOperands() const {
-    return const_cast<OperationExpr *>(this)->getOperands();
+    return getTrailingObjects<Expr *>(numOperands);
   }
 
   /// Return the result types of this operation.
@@ -550,10 +546,10 @@ class OperationExpr final
 
   /// Return the attributes of this operation.
   MutableArrayRef<NamedAttributeDecl *> getAttributes() {
-    return {getTrailingObjects<NamedAttributeDecl *>(), numAttributes};
+    return getTrailingObjects<NamedAttributeDecl *>(numAttributes);
   }
-  MutableArrayRef<NamedAttributeDecl *> getAttributes() const {
-    return const_cast<OperationExpr *>(this)->getAttributes();
+  ArrayRef<NamedAttributeDecl *> getAtt...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented May 7, 2025

@llvm/pr-subscribers-mlir

Author: Rahul Joshi (jurahul)

Changes

Add a specialization of getTrailingObjects() for a single trailing type. This is a common case and with the specialization you don't need to specify the single trailing type redundantly. Also add an overload for getTrailingObjects which takes size and returns an ArryaRef/MutableArrayRef as that's a common use case as well.

Fix some uses of TrailingObjects class to use private inheritance and remove unneeded numTrailingObjects function when there is a single trailing type.


Patch is 27.70 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/138554.diff

15 Files Affected:

  • (modified) llvm/include/llvm/DebugInfo/BTF/BTF.h (+2-4)
  • (modified) llvm/include/llvm/IR/DataLayout.h (+5-7)
  • (modified) llvm/include/llvm/Support/TrailingObjects.h (+33)
  • (modified) llvm/include/llvm/TableGen/Record.h (+7-13)
  • (modified) llvm/lib/Bitcode/Reader/BitcodeReader.cpp (+1-1)
  • (modified) llvm/lib/IR/AttributeImpl.h (+9-17)
  • (modified) llvm/lib/IR/Attributes.cpp (+2-2)
  • (modified) llvm/lib/Support/TrieRawHashMap.cpp (+2-2)
  • (modified) llvm/lib/Transforms/IPO/LowerTypeTests.cpp (+4-8)
  • (modified) llvm/unittests/Support/TrailingObjectsTest.cpp (+2-4)
  • (modified) mlir/include/mlir/IR/Operation.h (+2-3)
  • (modified) mlir/include/mlir/Tools/PDLL/AST/Nodes.h (+24-28)
  • (modified) mlir/lib/IR/AffineMapDetail.h (+5-3)
  • (modified) mlir/lib/IR/Location.cpp (+6-5)
  • (modified) mlir/lib/IR/TypeDetail.h (+4-5)
diff --git a/llvm/include/llvm/DebugInfo/BTF/BTF.h b/llvm/include/llvm/DebugInfo/BTF/BTF.h
index d88af2ff30bdf..bd666e4202985 100644
--- a/llvm/include/llvm/DebugInfo/BTF/BTF.h
+++ b/llvm/include/llvm/DebugInfo/BTF/BTF.h
@@ -304,14 +304,12 @@ enum PatchableRelocKind : uint32_t {
 // For CommonType sub-types that are followed by a single entry of
 // some type in the binary format.
 #define BTF_DEFINE_TAIL(Type, Accessor)                                        \
-  const Type &Accessor() const { return *getTrailingObjects<Type>(); }
+  const Type &Accessor() const { return *getTrailingObjects(); }
 
 // For CommonType sub-types that are followed by CommonType::getVlen()
 // number of entries of some type in the binary format.
 #define BTF_DEFINE_TAIL_ARR(Type, Accessor)                                    \
-  ArrayRef<Type> Accessor() const {                                            \
-    return ArrayRef<Type>(getTrailingObjects<Type>(), getVlen());              \
-  }
+  ArrayRef<Type> Accessor() const { return getTrailingObjects(getVlen()); }
 
 struct ArrayType final : CommonType,
                          private TrailingObjects<ArrayType, BTFArray> {
diff --git a/llvm/include/llvm/IR/DataLayout.h b/llvm/include/llvm/IR/DataLayout.h
index 2ad080e6d0cd2..d83fe1299237b 100644
--- a/llvm/include/llvm/IR/DataLayout.h
+++ b/llvm/include/llvm/IR/DataLayout.h
@@ -564,7 +564,9 @@ inline LLVMTargetDataRef wrap(const DataLayout *P) {
 
 /// Used to lazily calculate structure layout information for a target machine,
 /// based on the DataLayout structure.
-class StructLayout final : public TrailingObjects<StructLayout, TypeSize> {
+class StructLayout final : private TrailingObjects<StructLayout, TypeSize> {
+  friend TrailingObjects;
+
   TypeSize StructSize;
   Align StructAlignment;
   unsigned IsPadded : 1;
@@ -586,11 +588,11 @@ class StructLayout final : public TrailingObjects<StructLayout, TypeSize> {
   unsigned getElementContainingOffset(uint64_t FixedOffset) const;
 
   MutableArrayRef<TypeSize> getMemberOffsets() {
-    return llvm::MutableArrayRef(getTrailingObjects<TypeSize>(), NumElements);
+    return getTrailingObjects(NumElements);
   }
 
   ArrayRef<TypeSize> getMemberOffsets() const {
-    return llvm::ArrayRef(getTrailingObjects<TypeSize>(), NumElements);
+    return getTrailingObjects(NumElements);
   }
 
   TypeSize getElementOffset(unsigned Idx) const {
@@ -606,10 +608,6 @@ class StructLayout final : public TrailingObjects<StructLayout, TypeSize> {
   friend class DataLayout; // Only DataLayout can create this class
 
   StructLayout(StructType *ST, const DataLayout &DL);
-
-  size_t numTrailingObjects(OverloadToken<TypeSize>) const {
-    return NumElements;
-  }
 };
 
 // The implementation of this method is provided inline as it is particularly
diff --git a/llvm/include/llvm/Support/TrailingObjects.h b/llvm/include/llvm/Support/TrailingObjects.h
index 07cf08df45a6a..c6d274cd4301d 100644
--- a/llvm/include/llvm/Support/TrailingObjects.h
+++ b/llvm/include/llvm/Support/TrailingObjects.h
@@ -46,6 +46,7 @@
 #ifndef LLVM_SUPPORT_TRAILINGOBJECTS_H
 #define LLVM_SUPPORT_TRAILINGOBJECTS_H
 
+#include "llvm/ADT/ArrayRef.h"
 #include "llvm/Support/Alignment.h"
 #include "llvm/Support/Compiler.h"
 #include "llvm/Support/MathExtras.h"
@@ -301,6 +302,38 @@ class TrailingObjects : private trailing_objects_internal::TrailingObjectsImpl<
         static_cast<BaseTy *>(this), TrailingObjectsBase::OverloadToken<T>());
   }
 
+  // getTrailingObjects() specialization for a single trailing type.
+  using FirstTrailingType =
+      typename std::tuple_element_t<0, std::tuple<TrailingTys...>>;
+
+  const FirstTrailingType *getTrailingObjects() const {
+    static_assert(sizeof...(TrailingTys) == 1,
+                  "Can use non-templated getTrailingObjects() only when there "
+                  "is a single trailing type");
+    return getTrailingObjects<FirstTrailingType>();
+  }
+
+  FirstTrailingType *getTrailingObjects() {
+    static_assert(sizeof...(TrailingTys) == 1,
+                  "Can use non-templated getTrailingObjects() only when there "
+                  "is a single trailing type");
+    return getTrailingObjects<FirstTrailingType>();
+  }
+
+  // Functions that return the trailing objects as ArrayRefs.
+  template <typename T> MutableArrayRef<T> getTrailingObjects(size_t N) {
+    return {getTrailingObjects<T>(), N};
+  }
+  template <typename T> ArrayRef<T> getTrailingObjects(size_t N) const {
+    return {getTrailingObjects<T>(), N};
+  }
+  MutableArrayRef<FirstTrailingType> getTrailingObjects(size_t N) {
+    return {getTrailingObjects(), N};
+  }
+  ArrayRef<FirstTrailingType> getTrailingObjects(size_t N) const {
+    return {getTrailingObjects(), N};
+  }
+
   /// Returns the size of the trailing data, if an object were
   /// allocated with the given counts (The counts are in the same order
   /// as the template arguments). This does not include the size of the
diff --git a/llvm/include/llvm/TableGen/Record.h b/llvm/include/llvm/TableGen/Record.h
index 982cc255553a2..d5a138924953e 100644
--- a/llvm/include/llvm/TableGen/Record.h
+++ b/llvm/include/llvm/TableGen/Record.h
@@ -258,7 +258,7 @@ class RecordRecTy final : public RecTy,
   void Profile(FoldingSetNodeID &ID) const;
 
   ArrayRef<const Record *> getClasses() const {
-    return ArrayRef(getTrailingObjects<const Record *>(), NumClasses);
+    return getTrailingObjects(NumClasses);
   }
 
   using const_record_iterator = const Record *const *;
@@ -632,9 +632,7 @@ class BitsInit final : public TypedInit,
 
   const Init *resolveReferences(Resolver &R) const override;
 
-  ArrayRef<const Init *> getBits() const {
-    return ArrayRef(getTrailingObjects<const Init *>(), NumBits);
-  }
+  ArrayRef<const Init *> getBits() const { return getTrailingObjects(NumBits); }
 
   const Init *getBit(unsigned Bit) const override { return getBits()[Bit]; }
 };
@@ -1026,10 +1024,6 @@ class CondOpInit final : public TypedInit,
   CondOpInit(ArrayRef<const Init *> Conds, ArrayRef<const Init *> Values,
              const RecTy *Type);
 
-  size_t numTrailingObjects(OverloadToken<Init *>) const {
-    return 2*NumConds;
-  }
-
 public:
   CondOpInit(const CondOpInit &) = delete;
   CondOpInit &operator=(const CondOpInit &) = delete;
@@ -1053,11 +1047,11 @@ class CondOpInit final : public TypedInit,
   const Init *getVal(unsigned Num) const { return getVals()[Num]; }
 
   ArrayRef<const Init *> getConds() const {
-    return ArrayRef(getTrailingObjects<const Init *>(), NumConds);
+    return getTrailingObjects(NumConds);
   }
 
   ArrayRef<const Init *> getVals() const {
-    return ArrayRef(getTrailingObjects<const Init *>() + NumConds, NumConds);
+    return ArrayRef(getTrailingObjects() + NumConds, NumConds);
   }
 
   const Init *Fold(const Record *CurRec) const;
@@ -1375,7 +1369,7 @@ class VarDefInit final
   bool           args_empty() const { return NumArgs == 0; }
 
   ArrayRef<const ArgumentInit *> args() const {
-    return ArrayRef(getTrailingObjects<const ArgumentInit *>(), NumArgs);
+    return getTrailingObjects(NumArgs);
   }
 
   const Init *getBit(unsigned Bit) const override {
@@ -1488,11 +1482,11 @@ class DagInit final
   }
 
   ArrayRef<const Init *> getArgs() const {
-    return ArrayRef(getTrailingObjects<const Init *>(), NumArgs);
+    return getTrailingObjects<const Init *>(NumArgs);
   }
 
   ArrayRef<const StringInit *> getArgNames() const {
-    return ArrayRef(getTrailingObjects<const StringInit *>(), NumArgs);
+    return getTrailingObjects<const StringInit *>(NumArgs);
   }
 
   const Init *resolveReferences(Resolver &R) const override;
diff --git a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
index 4074ed65885c7..103502ae3e66f 100644
--- a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
+++ b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
@@ -560,7 +560,7 @@ class BitcodeConstant final : public Value,
   static bool classof(const Value *V) { return V->getValueID() == SubclassID; }
 
   ArrayRef<unsigned> getOperandIDs() const {
-    return ArrayRef(getTrailingObjects<unsigned>(), NumOperands);
+    return ArrayRef(getTrailingObjects(), NumOperands);
   }
 
   std::optional<ConstantRange> getInRange() const {
diff --git a/llvm/lib/IR/AttributeImpl.h b/llvm/lib/IR/AttributeImpl.h
index 59cc489ade40d..7dcc4fd2213e5 100644
--- a/llvm/lib/IR/AttributeImpl.h
+++ b/llvm/lib/IR/AttributeImpl.h
@@ -195,15 +195,12 @@ class StringAttributeImpl final
 
   unsigned KindSize;
   unsigned ValSize;
-  size_t numTrailingObjects(OverloadToken<char>) const {
-    return KindSize + 1 + ValSize + 1;
-  }
 
 public:
   StringAttributeImpl(StringRef Kind, StringRef Val = StringRef())
       : AttributeImpl(StringAttrEntry), KindSize(Kind.size()),
         ValSize(Val.size()) {
-    char *TrailingString = getTrailingObjects<char>();
+    char *TrailingString = getTrailingObjects();
     // Some users rely on zero-termination.
     llvm::copy(Kind, TrailingString);
     TrailingString[KindSize] = '\0';
@@ -212,10 +209,10 @@ class StringAttributeImpl final
   }
 
   StringRef getStringKind() const {
-    return StringRef(getTrailingObjects<char>(), KindSize);
+    return StringRef(getTrailingObjects(), KindSize);
   }
   StringRef getStringValue() const {
-    return StringRef(getTrailingObjects<char>() + KindSize + 1, ValSize);
+    return StringRef(getTrailingObjects() + KindSize + 1, ValSize);
   }
 
   static size_t totalSizeToAlloc(StringRef Kind, StringRef Val) {
@@ -250,25 +247,23 @@ class ConstantRangeListAttributeImpl final
   friend TrailingObjects;
 
   unsigned Size;
-  size_t numTrailingObjects(OverloadToken<ConstantRange>) const { return Size; }
 
 public:
   ConstantRangeListAttributeImpl(Attribute::AttrKind Kind,
                                  ArrayRef<ConstantRange> Val)
       : EnumAttributeImpl(ConstantRangeListAttrEntry, Kind), Size(Val.size()) {
     assert(Size > 0);
-    ConstantRange *TrailingCR = getTrailingObjects<ConstantRange>();
+    ConstantRange *TrailingCR = getTrailingObjects();
     std::uninitialized_copy(Val.begin(), Val.end(), TrailingCR);
   }
 
   ~ConstantRangeListAttributeImpl() {
-    ConstantRange *TrailingCR = getTrailingObjects<ConstantRange>();
-    for (unsigned I = 0; I != Size; ++I)
-      TrailingCR[I].~ConstantRange();
+    for (ConstantRange &CR : getTrailingObjects(Size))
+      CR.~ConstantRange();
   }
 
   ArrayRef<ConstantRange> getConstantRangeListValue() const {
-    return ArrayRef(getTrailingObjects<ConstantRange>(), Size);
+    return getTrailingObjects(Size);
   }
 
   static size_t totalSizeToAlloc(ArrayRef<ConstantRange> Val) {
@@ -353,7 +348,7 @@ class AttributeSetNode final
 
   using iterator = const Attribute *;
 
-  iterator begin() const { return getTrailingObjects<Attribute>(); }
+  iterator begin() const { return getTrailingObjects(); }
   iterator end() const { return begin() + NumAttrs; }
 
   void Profile(FoldingSetNodeID &ID) const {
@@ -383,9 +378,6 @@ class AttributeListImpl final
   /// Union of enum attributes available at any index.
   AttributeBitSet AvailableSomewhereAttrs;
 
-  // Helper fn for TrailingObjects class.
-  size_t numTrailingObjects(OverloadToken<AttributeSet>) { return NumAttrSets; }
-
 public:
   AttributeListImpl(ArrayRef<AttributeSet> Sets);
 
@@ -407,7 +399,7 @@ class AttributeListImpl final
 
   using iterator = const AttributeSet *;
 
-  iterator begin() const { return getTrailingObjects<AttributeSet>(); }
+  iterator begin() const { return getTrailingObjects(); }
   iterator end() const { return begin() + NumAttrSets; }
 
   void Profile(FoldingSetNodeID &ID) const;
diff --git a/llvm/lib/IR/Attributes.cpp b/llvm/lib/IR/Attributes.cpp
index 33ac8bfaf4e7c..5b0ceb34381a9 100644
--- a/llvm/lib/IR/Attributes.cpp
+++ b/llvm/lib/IR/Attributes.cpp
@@ -1237,7 +1237,7 @@ LLVM_DUMP_METHOD void AttributeSet::dump() const {
 AttributeSetNode::AttributeSetNode(ArrayRef<Attribute> Attrs)
     : NumAttrs(Attrs.size()) {
   // There's memory after the node where we can store the entries in.
-  llvm::copy(Attrs, getTrailingObjects<Attribute>());
+  llvm::copy(Attrs, getTrailingObjects());
 
   for (const auto &I : *this) {
     if (I.isStringAttribute())
@@ -1423,7 +1423,7 @@ AttributeListImpl::AttributeListImpl(ArrayRef<AttributeSet> Sets)
   assert(!Sets.empty() && "pointless AttributeListImpl");
 
   // There's memory after the node where we can store the entries in.
-  llvm::copy(Sets, getTrailingObjects<AttributeSet>());
+  llvm::copy(Sets, getTrailingObjects());
 
   // Initialize AvailableFunctionAttrs and AvailableSomewhereAttrs
   // summary bitsets.
diff --git a/llvm/lib/Support/TrieRawHashMap.cpp b/llvm/lib/Support/TrieRawHashMap.cpp
index 11d79a62d011d..bb779fe87ae62 100644
--- a/llvm/lib/Support/TrieRawHashMap.cpp
+++ b/llvm/lib/Support/TrieRawHashMap.cpp
@@ -62,7 +62,7 @@ class TrieSubtrie final
 public:
   using Slot = LazyAtomicPointer<TrieNode>;
 
-  Slot &get(size_t I) { return getTrailingObjects<Slot>()[I]; }
+  Slot &get(size_t I) { return getTrailingObjects()[I]; }
   TrieNode *load(size_t I) { return get(I).load(); }
 
   unsigned size() const { return Size; }
@@ -190,7 +190,7 @@ class ThreadSafeTrieRawHashMapBase::ImplType final
   }
 
   // Get the root which is the trailing object.
-  TrieSubtrie *getRoot() { return getTrailingObjects<TrieSubtrie>(); }
+  TrieSubtrie *getRoot() { return getTrailingObjects(); }
 
   static void *operator new(size_t Size) { return ::operator new(Size); }
   void operator delete(void *Ptr) { ::operator delete(Ptr); }
diff --git a/llvm/lib/Transforms/IPO/LowerTypeTests.cpp b/llvm/lib/Transforms/IPO/LowerTypeTests.cpp
index c93568943e833..aa6672a57f9a5 100644
--- a/llvm/lib/Transforms/IPO/LowerTypeTests.cpp
+++ b/llvm/lib/Transforms/IPO/LowerTypeTests.cpp
@@ -285,8 +285,6 @@ class GlobalTypeMember final : TrailingObjects<GlobalTypeMember, MDNode *> {
   // module and its jumptable entry needs to be exported to thinlto backends.
   bool IsExported;
 
-  size_t numTrailingObjects(OverloadToken<MDNode *>) const { return NTypes; }
-
 public:
   static GlobalTypeMember *create(BumpPtrAllocator &Alloc, GlobalObject *GO,
                                   bool IsJumpTableCanonical, bool IsExported,
@@ -298,7 +296,7 @@ class GlobalTypeMember final : TrailingObjects<GlobalTypeMember, MDNode *> {
     GTM->IsJumpTableCanonical = IsJumpTableCanonical;
     GTM->IsExported = IsExported;
     std::uninitialized_copy(Types.begin(), Types.end(),
-                            GTM->getTrailingObjects<MDNode *>());
+                            GTM->getTrailingObjects());
     return GTM;
   }
 
@@ -314,9 +312,7 @@ class GlobalTypeMember final : TrailingObjects<GlobalTypeMember, MDNode *> {
     return IsExported;
   }
 
-  ArrayRef<MDNode *> types() const {
-    return ArrayRef(getTrailingObjects<MDNode *>(), NTypes);
-  }
+  ArrayRef<MDNode *> types() const { return getTrailingObjects(NTypes); }
 };
 
 struct ICallBranchFunnel final
@@ -331,13 +327,13 @@ struct ICallBranchFunnel final
     Call->UniqueId = UniqueId;
     Call->NTargets = Targets.size();
     std::uninitialized_copy(Targets.begin(), Targets.end(),
-                            Call->getTrailingObjects<GlobalTypeMember *>());
+                            Call->getTrailingObjects());
     return Call;
   }
 
   CallInst *CI;
   ArrayRef<GlobalTypeMember *> targets() const {
-    return ArrayRef(getTrailingObjects<GlobalTypeMember *>(), NTargets);
+    return getTrailingObjects(NTargets);
   }
 
   unsigned UniqueId;
diff --git a/llvm/unittests/Support/TrailingObjectsTest.cpp b/llvm/unittests/Support/TrailingObjectsTest.cpp
index e2656b2229ca6..6e4447cd9d670 100644
--- a/llvm/unittests/Support/TrailingObjectsTest.cpp
+++ b/llvm/unittests/Support/TrailingObjectsTest.cpp
@@ -21,11 +21,9 @@ class Class1 final : protected TrailingObjects<Class1, short> {
   unsigned NumShorts;
 
 protected:
-  size_t numTrailingObjects(OverloadToken<short>) const { return NumShorts; }
-
   Class1(int *ShortArray, unsigned NumShorts) : NumShorts(NumShorts) {
     std::uninitialized_copy(ShortArray, ShortArray + NumShorts,
-                            getTrailingObjects<short>());
+                            getTrailingObjects());
   }
 
 public:
@@ -35,7 +33,7 @@ class Class1 final : protected TrailingObjects<Class1, short> {
   }
   void operator delete(void *p) { ::operator delete(p); }
 
-  short get(unsigned Num) const { return getTrailingObjects<short>()[Num]; }
+  short get(unsigned Num) const { return getTrailingObjects()[Num]; }
 
   unsigned numShorts() const { return NumShorts; }
 
diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
index 95d944170732e..68ab1527b480a 100644
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -679,8 +679,7 @@ class alignas(8) Operation final
     if (numRegions == 0)
       return MutableArrayRef<Region>();
 
-    auto *regions = getTrailingObjects<Region>();
-    return {regions, numRegions};
+    return getTrailingObjects<Region>(numRegions);
   }
 
   /// Returns the region held by this operation at position 'index'.
@@ -694,7 +693,7 @@ class alignas(8) Operation final
   //===--------------------------------------------------------------------===//
 
   MutableArrayRef<BlockOperand> getBlockOperands() {
-    return {getTrailingObjects<BlockOperand>(), numSuccs};
+    return getTrailingObjects<BlockOperand>(numSuccs);
   }
 
   // Successor iteration.
diff --git a/mlir/include/mlir/Tools/PDLL/AST/Nodes.h b/mlir/include/mlir/Tools/PDLL/AST/Nodes.h
index f174ac2f476f6..9ad94839890b7 100644
--- a/mlir/include/mlir/Tools/PDLL/AST/Nodes.h
+++ b/mlir/include/mlir/Tools/PDLL/AST/Nodes.h
@@ -183,10 +183,10 @@ class CompoundStmt final : public Node::NodeBase<CompoundStmt, Stmt>,
 
   /// Return the children of this compound statement.
   MutableArrayRef<Stmt *> getChildren() {
-    return {getTrailingObjects<Stmt *>(), numChildren};
+    return getTrailingObjects(numChildren);
   }
   ArrayRef<Stmt *> getChildren() const {
-    return const_cast<CompoundStmt *>(this)->getChildren();
+    return getTrailingObjects(numChildren);
   }
   ArrayRef<Stmt *>::iterator begin() const { return getChildren().begin(); }
   ArrayRef<Stmt *>::iterator end() const { return getChildren().end(); }
@@ -275,10 +275,10 @@ class ReplaceStmt final : public Node::NodeBase<ReplaceStmt, OpRewriteStmt>,
 
   /// Return the replacement values of this statement.
   MutableArrayRef<Expr *> getReplExprs() {
-    return {getTrailingObjects<Expr *>(), numReplExprs};
+    return getTrailingObjects(numReplExprs);
   }
   ArrayRef<Expr *> getReplExprs() const {
-    return const_cast<ReplaceStmt *>(this)->getReplExprs();
+    return getTrailingObjects(numReplExprs);
   }
 
 private:
@@ -400,12 +400,8 @@ class CallExpr final : public Node::NodeBase<CallExpr, Expr>,
   Expr *getCallableExpr() const { return callable; }
 
   /// Return the arguments of this call.
-  MutableArrayRef<Expr *> getArguments() {
-    return {getTrailingObjects<Expr *>(), numArgs};
-  }
-  ArrayRef<Expr *> getArguments() const {
-    return const_cast<CallExpr *>(this)->getArguments();
-  }
+  MutableArrayRef<Expr *> getArguments() { return getTrailingObjects(numArgs); }
+  ArrayRef<Expr *> getArguments() const { return getTrailingObjects(numArgs); }
 
   /// Returns whether the result of this call is to be negated.
   bool getIsNegated() const { return isNegated; }
@@ -534,10 +530,10 @@ class OperationExpr final
 
   /// Return the operands of this operation.
   MutableArrayRef<Expr *> getOperands() {
-    return {getTrailingObjects<Expr *>(), numOperands};
+    return getTrailingObjects<Expr *>(numOperands);
   }
   ArrayRef<Expr *> getOperands() const {
-    return const_cast<OperationExpr *>(this)->getOperands();
+    return getTrailingObjects<Expr *>(numOperands);
   }
 
   /// Return the result types of this operation.
@@ -550,10 +546,10 @@ class OperationExpr final
 
   /// Return the attributes of this operation.
   MutableArrayRef<NamedAttributeDecl *> getAttributes() {
-    return {getTrailingObjects<NamedAttributeDecl *>(), numAttributes};
+    return getTrailingObjects<NamedAttributeDecl *>(numAttributes);
   }
-  MutableArrayRef<NamedAttributeDecl *> getAttributes() const {
-    return const_cast<OperationExpr *>(this)->getAttributes();
+  ArrayRef<NamedAttributeDecl *> getAtt...
[truncated]

@fmayer
Copy link
Contributor

fmayer commented May 7, 2025

I think this patch is [NFC], or at least [NFCI]

@joker-eph
Copy link
Collaborator

MLIR changes LGTM, but I second that this should be marked NFC (otherwise I'd expect tests)

@dwblaikie
Copy link
Collaborator

I'd expect this to be split up - into the Support API additions and tests (thanks for adding unit test coverage!) (which isn't exactly NFC - we treat the ADT/Support libraries somewhat as a functional library boundary that we unit test, etc) then the usage (which would be NFC - only changing how the code achieves the same goals, without adding/removing functionality from any command line tools)

@jurahul
Copy link
Contributor Author

jurahul commented May 7, 2025

I'd expect this to be split up - into the Support API additions and tests (thanks for adding unit test coverage!) (which isn't exactly NFC - we treat the ADT/Support libraries somewhat as a functional library boundary that we unit test, etc) then the usage (which would be NFC - only changing how the code achieves the same goals, without adding/removing functionality from any command line tools)

Sounds good. I'll create a "non NFC?" PR to just add the new API and unit tests, and then rebase this on top of that.

@jurahul
Copy link
Contributor Author

jurahul commented May 7, 2025

Changing this to draft while I extract out the API changes into a new PR,

@jurahul jurahul marked this pull request as draft May 7, 2025 18:40
@jurahul
Copy link
Contributor Author

jurahul commented May 8, 2025

Started #138970 for just the new API additions

@jurahul jurahul force-pushed the trailing_objects branch from b229281 to a1c9ea6 Compare May 9, 2025 19:33
@jurahul jurahul changed the title [LLVM][Support] Add getTrailingObjects() for single trailing type. [NFCI][LLVM/MLIR] Adopt TrailingObjects convienence API May 9, 2025
@jurahul jurahul changed the title [NFCI][LLVM/MLIR] Adopt TrailingObjects convienence API [NFCI][LLVM/MLIR] Adopt TrailingObjects convenience API May 9, 2025
@jurahul jurahul force-pushed the trailing_objects branch from a1c9ea6 to 87a3d26 Compare May 10, 2025 13:16
@jurahul jurahul marked this pull request as ready for review May 10, 2025 15:41
@jurahul jurahul requested review from dwblaikie, fmayer and joker-eph May 10, 2025 15:42
Adopt `TrailingObjects` convienence API that was added in
llvm#138970 in LLVM and MLIR code.
@jurahul jurahul merged commit 86c5112 into llvm:main May 12, 2025
12 checks passed
@jurahul jurahul deleted the trailing_objects branch May 12, 2025 23:17
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants