Skip to content

Conversation

@IgWod-IMG
Copy link
Contributor

An alternative implementation could use ArrayRef of NamedAttributes or NamedAttrList to store structs decorations, as the deserializer uses NamedAttributes for decorations. However, using a custom struct allows us to store the spirv::Decorations directly rather than its name in a StringRef/StringAttr.

@llvmbot
Copy link
Member

llvmbot commented Jul 21, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-spirv

Author: Igor Wodiany (IgWod-IMG)

Changes

An alternative implementation could use ArrayRef of NamedAttributes or NamedAttrList to store structs decorations, as the deserializer uses NamedAttributes for decorations. However, using a custom struct allows us to store the spirv::Decorations directly rather than its name in a StringRef/StringAttr.


Patch is 38.09 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/149793.diff

10 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h (+37-2)
  • (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp (+56-5)
  • (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp (+93-25)
  • (modified) mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp (+23-11)
  • (modified) mlir/lib/Target/SPIRV/Deserialization/Deserializer.h (+1)
  • (modified) mlir/lib/Target/SPIRV/Serialization/Serializer.cpp (+24-4)
  • (modified) mlir/test/Dialect/SPIRV/IR/types.mlir (+6)
  • (modified) mlir/test/Target/SPIRV/memory-ops.mlir (+10-10)
  • (modified) mlir/test/Target/SPIRV/struct.mlir (+20-18)
  • (modified) mlir/test/Target/SPIRV/undef.mlir (+3-3)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index 212cba61d396c..56d09301345f9 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -327,10 +327,33 @@ class StructType
     }
   };
 
+  // Type for specifying the decoration(s) on the struct itself
+  struct StructDecorationInfo {
+    bool hasValue;
+    Decoration decoration;
+    Attribute decorationValue;
+
+    StructDecorationInfo(bool hasValue, Decoration decoration,
+                         Attribute decorationValue)
+        : hasValue(hasValue), decoration(decoration),
+          decorationValue(decorationValue) {}
+
+    bool operator==(const StructDecorationInfo &other) const {
+      return (this->decoration == other.decoration) &&
+             (this->decorationValue == other.decorationValue);
+    }
+
+    bool operator<(const StructDecorationInfo &other) const {
+      return static_cast<uint32_t>(this->decoration) <
+             static_cast<uint32_t>(other.decoration);
+    }
+  };
+
   /// Construct a literal StructType with at least one member.
   static StructType get(ArrayRef<Type> memberTypes,
                         ArrayRef<OffsetInfo> offsetInfo = {},
-                        ArrayRef<MemberDecorationInfo> memberDecorations = {});
+                        ArrayRef<MemberDecorationInfo> memberDecorations = {},
+                        ArrayRef<StructDecorationInfo> structDecorations = {});
 
   /// Construct an identified StructType. This creates a StructType whose body
   /// (member types, offset info, and decorations) is not set yet. A call to
@@ -364,6 +387,9 @@ class StructType
 
   bool hasOffset() const;
 
+  /// Returns true if the struct has a specified decoration.
+  bool hasDecoration(spirv::Decoration decoration) const;
+
   uint64_t getMemberOffset(unsigned) const;
 
   // Returns in `memberDecorations` the Decorations (apart from Offset)
@@ -377,12 +403,18 @@ class StructType
       unsigned i,
       SmallVectorImpl<StructType::MemberDecorationInfo> &decorationsInfo) const;
 
+  // Returns in `structDecorations` the Decorations associated with the
+  // StructType.
+  void getStructDecorations(SmallVectorImpl<StructType::StructDecorationInfo>
+                                &structDecorations) const;
+
   /// Sets the contents of an incomplete identified StructType. This method must
   /// be called only for identified StructTypes and it must be called only once
   /// per instance. Otherwise, failure() is returned.
   LogicalResult
   trySetBody(ArrayRef<Type> memberTypes, ArrayRef<OffsetInfo> offsetInfo = {},
-             ArrayRef<MemberDecorationInfo> memberDecorations = {});
+             ArrayRef<MemberDecorationInfo> memberDecorations = {},
+             ArrayRef<StructDecorationInfo> structDecorations = {});
 
   void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
                      std::optional<StorageClass> storage = std::nullopt);
@@ -393,6 +425,9 @@ class StructType
 llvm::hash_code
 hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo);
 
+llvm::hash_code
+hash_value(const StructType::StructDecorationInfo &structDecorationInfo);
+
 // SPIR-V KHR cooperative matrix type
 class CooperativeMatrixType
     : public Type::TypeBase<CooperativeMatrixType, CompositeType,
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index f32c53b8f0b9e..6121fef7318bb 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -693,7 +693,9 @@ static ParseResult parseStructMemberDecorations(
 //             `!spirv.struct<` (id `,`)?
 //                          `(`
 //                            (spirv-type (`[` struct-member-decoration `]`)?)*
-//                          `)>`
+//                          `)`
+//                            (`,` struct-decoration)?
+//                          `>`
 static Type parseStructType(SPIRVDialect const &dialect,
                             DialectAsmParser &parser) {
   // TODO: This function is quite lengthy. Break it down into smaller chunks.
@@ -767,17 +769,50 @@ static Type parseStructType(SPIRVDialect const &dialect,
     return Type();
   }
 
-  if (failed(parser.parseRParen()) || failed(parser.parseGreater()))
+  if (failed(parser.parseRParen()))
+    return Type();
+
+  SmallVector<StructType::StructDecorationInfo, 0> structDecorationInfo;
+
+  auto parseStructDecoration = [&]() {
+    std::optional<spirv::Decoration> decoration =
+        parseAndVerify<spirv::Decoration>(dialect, parser);
+    if (!decoration)
+      return failure();
+
+    // Parse decoration value if it exists.
+    if (succeeded(parser.parseOptionalEqual())) {
+      Attribute decorationValue;
+
+      if (failed(parser.parseAttribute(decorationValue)))
+        return failure();
+
+      structDecorationInfo.emplace_back(true, decoration.value(),
+                                        decorationValue);
+    } else {
+      structDecorationInfo.emplace_back(false, decoration.value(),
+                                        UnitAttr::get(dialect.getContext()));
+    }
+    return success();
+  };
+
+  while (succeeded(parser.parseOptionalComma()))
+    if (failed(parseStructDecoration()))
+      return Type();
+
+  if (failed(parser.parseGreater()))
     return Type();
 
   if (!identifier.empty()) {
     if (failed(idStructTy.trySetBody(memberTypes, offsetInfo,
-                                     memberDecorationInfo)))
+                                     memberDecorationInfo,
+                                     structDecorationInfo)))
       return Type();
     return idStructTy;
   }
 
-  return StructType::get(memberTypes, offsetInfo, memberDecorationInfo);
+  return StructType::get(memberTypes, offsetInfo, memberDecorationInfo,
+                         structDecorationInfo);
 }
 
 // spirv-type ::= array-type
@@ -892,7 +927,23 @@ static void print(StructType type, DialectAsmPrinter &os) {
   };
   llvm::interleaveComma(llvm::seq<unsigned>(0, type.getNumElements()), os,
                         printMember);
-  os << ")>";
+  os << ")";
+
+  SmallVector<spirv::StructType::StructDecorationInfo, 0> decorations;
+  type.getStructDecorations(decorations);
+  if (!decorations.empty()) {
+    os << ", ";
+    auto eachFn = [&os](spirv::StructType::StructDecorationInfo decoration) {
+      os << stringifyDecoration(decoration.decoration);
+      if (decoration.hasValue) {
+        os << "=";
+        os.printAttributeWithoutType(decoration.decorationValue);
+      }
+    };
+    llvm::interleaveComma(decorations, os, eachFn);
+  }
+
+  os << ">";
 }
 
 static void print(CooperativeMatrixType type, DialectAsmPrinter &os) {
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 3799abd6fc743..4bb06b349d040 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -837,12 +837,14 @@ void SampledImageType::getCapabilities(
 /// - for literal structs:
 ///   - a list of member types;
 ///   - a list of member offset info;
-///   - a list of member decoration info.
+///   - a list of member decoration info;
+///   - a list of struct decoration info.
 ///
 /// Identified structures only have a mutable component consisting of:
 /// - a list of member types;
 /// - a list of member offset info;
-/// - a list of member decoration info.
+/// - a list of member decoration info;
+/// - a list of struct decoration info.
 struct spirv::detail::StructTypeStorage : public TypeStorage {
   /// Construct a storage object for an identified struct type. A struct type
   /// associated with such storage must call StructType::trySetBody(...) later
@@ -850,6 +852,7 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
   StructTypeStorage(StringRef identifier)
       : memberTypesAndIsBodySet(nullptr, false), offsetInfo(nullptr),
         numMembers(0), numMemberDecorations(0), memberDecorationsInfo(nullptr),
+        numStructDecorations(0), structDecorationsInfo(nullptr),
         identifier(identifier) {}
 
   /// Construct a storage object for a literal struct type. A struct type
@@ -857,10 +860,14 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
   StructTypeStorage(
       unsigned numMembers, Type const *memberTypes,
       StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations,
-      StructType::MemberDecorationInfo const *memberDecorationsInfo)
+      StructType::MemberDecorationInfo const *memberDecorationsInfo,
+      unsigned numStructDecorations,
+      StructType::StructDecorationInfo const *structDecorationsInfo)
       : memberTypesAndIsBodySet(memberTypes, false), offsetInfo(layoutInfo),
         numMembers(numMembers), numMemberDecorations(numMemberDecorations),
-        memberDecorationsInfo(memberDecorationsInfo) {}
+        memberDecorationsInfo(memberDecorationsInfo),
+        numStructDecorations(numStructDecorations),
+        structDecorationsInfo(structDecorationsInfo) {}
 
   /// A storage key is divided into 2 parts:
   /// - for identified structs:
@@ -869,16 +876,19 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
   ///   - an ArrayRef<Type> for member types;
   ///   - an ArrayRef<StructType::OffsetInfo> for member offset info;
   ///   - an ArrayRef<StructType::MemberDecorationInfo> for member decoration
+  ///     info;
+  ///   - an ArrayRef<StructType::StructDecorationInfo> for struct decoration
   ///     info.
   ///
   /// An identified struct type is uniqued only by the first part (field 0)
   /// of the key.
   ///
-  /// A literal struct type is uniqued only by the second part (fields 1, 2, and
-  /// 3) of the key. The identifier field (field 0) must be empty.
+  /// A literal struct type is uniqued only by the second part (fields 1, 2, 3
+  /// and 4) of the key. The identifier field (field 0) must be empty.
   using KeyTy =
       std::tuple<StringRef, ArrayRef<Type>, ArrayRef<StructType::OffsetInfo>,
-                 ArrayRef<StructType::MemberDecorationInfo>>;
+                 ArrayRef<StructType::MemberDecorationInfo>,
+                 ArrayRef<StructType::StructDecorationInfo>>;
 
   /// For identified structs, return true if the given key contains the same
   /// identifier.
@@ -892,7 +902,7 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
     }
 
     return key == KeyTy(StringRef(), getMemberTypes(), getOffsetInfo(),
-                        getMemberDecorationsInfo());
+                        getMemberDecorationsInfo(), getStructDecorationsInfo());
   }
 
   /// If the given key contains a non-empty identifier, this method constructs
@@ -939,9 +949,17 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
       memberDecorationList = allocator.copyInto(keyMemberDecorations).data();
     }
 
-    return new (allocator.allocate<StructTypeStorage>())
-        StructTypeStorage(keyTypes.size(), typesList, offsetInfoList,
-                          numMemberDecorations, memberDecorationList);
+    const StructType::StructDecorationInfo *structDecorationList = nullptr;
+    unsigned numStructDecorations = 0;
+    if (!std::get<4>(key).empty()) {
+      auto keyStructDecorations = std::get<4>(key);
+      numStructDecorations = keyStructDecorations.size();
+      structDecorationList = allocator.copyInto(keyStructDecorations).data();
+    }
+
+    return new (allocator.allocate<StructTypeStorage>()) StructTypeStorage(
+        keyTypes.size(), typesList, offsetInfoList, numMemberDecorations,
+        memberDecorationList, numStructDecorations, structDecorationList);
   }
 
   ArrayRef<Type> getMemberTypes() const {
@@ -963,6 +981,13 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
     return {};
   }
 
+  ArrayRef<StructType::StructDecorationInfo> getStructDecorationsInfo() const {
+    if (structDecorationsInfo)
+      return ArrayRef<StructType::StructDecorationInfo>(structDecorationsInfo,
+                                                        numStructDecorations);
+    return {};
+  }
+
   StringRef getIdentifier() const { return identifier; }
 
   bool isIdentified() const { return !identifier.empty(); }
@@ -975,17 +1000,19 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
   /// - If called for an identified struct whose body was set before (through a
   /// call to this method) but with different contents from the passed
   /// arguments.
-  LogicalResult mutate(
-      TypeStorageAllocator &allocator, ArrayRef<Type> structMemberTypes,
-      ArrayRef<StructType::OffsetInfo> structOffsetInfo,
-      ArrayRef<StructType::MemberDecorationInfo> structMemberDecorationInfo) {
+  LogicalResult
+  mutate(TypeStorageAllocator &allocator, ArrayRef<Type> structMemberTypes,
+         ArrayRef<StructType::OffsetInfo> structOffsetInfo,
+         ArrayRef<StructType::MemberDecorationInfo> structMemberDecorationInfo,
+         ArrayRef<StructType::StructDecorationInfo> structDecorationInfo) {
     if (!isIdentified())
       return failure();
 
     if (memberTypesAndIsBodySet.getInt() &&
         (getMemberTypes() != structMemberTypes ||
          getOffsetInfo() != structOffsetInfo ||
-         getMemberDecorationsInfo() != structMemberDecorationInfo))
+         getMemberDecorationsInfo() != structMemberDecorationInfo ||
+         getStructDecorationsInfo() != structDecorationInfo))
       return failure();
 
     memberTypesAndIsBodySet.setInt(true);
@@ -1009,6 +1036,11 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
           allocator.copyInto(structMemberDecorationInfo).data();
     }
 
+    if (!structDecorationInfo.empty()) {
+      numStructDecorations = structDecorationInfo.size();
+      structDecorationsInfo = allocator.copyInto(structDecorationInfo).data();
+    }
+
     return success();
   }
 
@@ -1017,21 +1049,30 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
   unsigned numMembers;
   unsigned numMemberDecorations;
   StructType::MemberDecorationInfo const *memberDecorationsInfo;
+  unsigned numStructDecorations;
+  StructType::StructDecorationInfo const *structDecorationsInfo;
   StringRef identifier;
 };
 
 StructType
 StructType::get(ArrayRef<Type> memberTypes,
                 ArrayRef<StructType::OffsetInfo> offsetInfo,
-                ArrayRef<StructType::MemberDecorationInfo> memberDecorations) {
+                ArrayRef<StructType::MemberDecorationInfo> memberDecorations,
+                ArrayRef<StructType::StructDecorationInfo> structDecorations) {
   assert(!memberTypes.empty() && "Struct needs at least one member type");
   // Sort the decorations.
-  SmallVector<StructType::MemberDecorationInfo, 4> sortedDecorations(
+  SmallVector<StructType::MemberDecorationInfo, 4> sortedMemberDecorations(
       memberDecorations);
-  llvm::array_pod_sort(sortedDecorations.begin(), sortedDecorations.end());
+  llvm::array_pod_sort(sortedMemberDecorations.begin(),
+                       sortedMemberDecorations.end());
+  SmallVector<StructType::StructDecorationInfo, 4> sortedStructDecorations(
+      structDecorations);
+  llvm::array_pod_sort(sortedStructDecorations.begin(),
+                       sortedStructDecorations.end());
+
   return Base::get(memberTypes.vec().front().getContext(),
                    /*identifier=*/StringRef(), memberTypes, offsetInfo,
-                   sortedDecorations);
+                   sortedMemberDecorations, sortedStructDecorations);
 }
 
 StructType StructType::getIdentified(MLIRContext *context,
@@ -1041,18 +1082,21 @@ StructType StructType::getIdentified(MLIRContext *context,
 
   return Base::get(context, identifier, ArrayRef<Type>(),
                    ArrayRef<StructType::OffsetInfo>(),
-                   ArrayRef<StructType::MemberDecorationInfo>());
+                   ArrayRef<StructType::MemberDecorationInfo>(),
+                   ArrayRef<StructType::StructDecorationInfo>());
 }
 
 StructType StructType::getEmpty(MLIRContext *context, StringRef identifier) {
   StructType newStructType = Base::get(
       context, identifier, ArrayRef<Type>(), ArrayRef<StructType::OffsetInfo>(),
-      ArrayRef<StructType::MemberDecorationInfo>());
+      ArrayRef<StructType::MemberDecorationInfo>(),
+      ArrayRef<StructType::StructDecorationInfo>());
   // Set an empty body in case this is a identified struct.
   if (newStructType.isIdentified() &&
       failed(newStructType.trySetBody(
           ArrayRef<Type>(), ArrayRef<StructType::OffsetInfo>(),
-          ArrayRef<StructType::MemberDecorationInfo>())))
+          ArrayRef<StructType::MemberDecorationInfo>(),
+          ArrayRef<StructType::StructDecorationInfo>())))
     return StructType();
 
   return newStructType;
@@ -1076,6 +1120,15 @@ TypeRange StructType::getElementTypes() const {
 
 bool StructType::hasOffset() const { return getImpl()->offsetInfo; }
 
+bool StructType::hasDecoration(spirv::Decoration decoration) const {
+  for (StructType::StructDecorationInfo info :
+       getImpl()->getStructDecorationsInfo())
+    if (info.decoration == decoration)
+      return true;
+
+  return false;
+}
+
 uint64_t StructType::getMemberOffset(unsigned index) const {
   assert(getNumElements() > index && "member index out of range");
   return getImpl()->offsetInfo[index];
@@ -1107,11 +1160,21 @@ void StructType::getMemberDecorations(
   }
 }
 
+void StructType::getStructDecorations(
+    SmallVectorImpl<StructType::StructDecorationInfo> &structDecorations)
+    const {
+  structDecorations.clear();
+  auto implDecorations = getImpl()->getStructDecorationsInfo();
+  structDecorations.append(implDecorations.begin(), implDecorations.end());
+}
+
 LogicalResult
 StructType::trySetBody(ArrayRef<Type> memberTypes,
                        ArrayRef<OffsetInfo> offsetInfo,
-                       ArrayRef<MemberDecorationInfo> memberDecorations) {
-  return Base::mutate(memberTypes, offsetInfo, memberDecorations);
+                       ArrayRef<MemberDecorationInfo> memberDecorations,
+                       ArrayRef<StructDecorationInfo> structDecorations) {
+  return Base::mutate(memberTypes, offsetInfo, memberDecorations,
+                      structDecorations);
 }
 
 void StructType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
@@ -1133,6 +1196,11 @@ llvm::hash_code spirv::hash_value(
                             memberDecorationInfo.decoration);
 }
 
+llvm::hash_code spirv::hash_value(
+    const StructType::StructDecorationInfo &structDecorationInfo) {
+  return llvm::hash_value(structDecorationInfo.decoration);
+}
+
 //===----------------------------------------------------------------------===//
 // MatrixType
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index d133d0332e271..c8aa67c8c3b0d 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -347,10 +347,6 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
       return emitError(unknownLoc, "OpDecoration with ")
              << decorationName << "needs a single target <id>";
     }
-    // Block decoration does not affect spirv.struct type, but is still stored
-    // for verification.
-    // TODO: Update StructType to contain this information since
-    // it is needed for many validation rules.
     decorations[words[0]].set(symbol, opBuilder.getUnitAttr());
     break;
   case spirv::Decoration::Location:
@@ -993,7 +989,8 @@ spirv::Deserializer::processOpTypePointer(ArrayRef<uint32_t> operands) {
 
       if (failed(structType.trySetBody(
               deferredStructIt->memberTypes, deferredStructIt->offsetInfo,
-              deferredStructIt->memberDecorationsInfo)))
+              deferredStructIt->memberDecorationsInfo,
+              deferredStructIt->structDecorationsInfo)))
         return failure();
 
       deferredStructIt = deferredStructTypesInfos.erase(deferredStructIt);
@@ -1202,24 +1199,39 @@ spirv::Deserializer::processStructType(ArrayRef<uint32_t> operands) {
     }
   }
 
+  SmallVector<spirv::StructType::StructDecorationInfo, 0> structDecorationsInfo;
+  i...
[truncated]

@kuhar kuhar requested review from Hardcode84 and krzysz00 July 21, 2025 14:49
Copy link
Contributor

@krzysz00 krzysz00 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not seeing any fundamental issues here

An alternative implementation could use ArrayRef of NamedAttributes
or NamedAttrList to store structs decorations, as the deserializer
uses NamedAttributes for decorations. However, using a custom struct
allows us to store the spirv::Decorations directly rather than its
name in a StringRef.
@IgWod-IMG IgWod-IMG force-pushed the img_struct-decorations branch from 9d2abda to 60c2d95 Compare July 24, 2025 14:38
@IgWod-IMG
Copy link
Contributor Author

I have rebased it on #150218 and addressed the outstanding feedback.

@kuhar
Copy link
Member

kuhar commented Jul 24, 2025

@Hardcode84 do you want to review before we merge?

@IgWod-IMG
Copy link
Contributor Author

Ping

@IgWod-IMG IgWod-IMG merged commit 17c1921 into llvm:main Jul 30, 2025
9 checks passed
@IgWod-IMG IgWod-IMG deleted the img_struct-decorations branch July 30, 2025 13:40
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants