Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 38 additions & 2 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -330,10 +330,34 @@ class StructType
bool hasValue() const { return !isa<UnitAttr>(decorationValue); }
};

// Type for specifying the decoration(s) on the struct itself.
struct StructDecorationInfo {
Decoration decoration;
Attribute decorationValue;

StructDecorationInfo(Decoration decoration, Attribute decorationValue)
: decoration(decoration), decorationValue(decorationValue) {}

friend bool operator==(const StructDecorationInfo &lhs,
const StructDecorationInfo &rhs) {
return lhs.decoration == rhs.decoration &&
lhs.decorationValue == rhs.decorationValue;
}

friend bool operator<(const StructDecorationInfo &lhs,
const StructDecorationInfo &rhs) {
return llvm::to_underlying(lhs.decoration) <
llvm::to_underlying(rhs.decoration);
}

bool hasValue() const { return !isa<UnitAttr>(decorationValue); }
};

/// Construct a literal StructType with at least one member.
static StructType get(ArrayRef<Type> memberTypes,
ArrayRef<OffsetInfo> offsetInfo = {},
ArrayRef<MemberDecorationInfo> memberDecorations = {});
ArrayRef<MemberDecorationInfo> memberDecorations = {},
ArrayRef<StructDecorationInfo> structDecorations = {});

/// Construct an identified StructType. This creates a StructType whose body
/// (member types, offset info, and decorations) is not set yet. A call to
Expand Down Expand Up @@ -367,6 +391,9 @@ class StructType

bool hasOffset() const;

/// Returns true if the struct has a specified decoration.
bool hasDecoration(spirv::Decoration decoration) const;

uint64_t getMemberOffset(unsigned) const;

// Returns in `memberDecorations` the Decorations (apart from Offset)
Expand All @@ -380,12 +407,18 @@ class StructType
unsigned i,
SmallVectorImpl<StructType::MemberDecorationInfo> &decorationsInfo) const;

// Returns in `structDecorations` the Decorations associated with the
// StructType.
void getStructDecorations(SmallVectorImpl<StructType::StructDecorationInfo>
&structDecorations) const;

/// Sets the contents of an incomplete identified StructType. This method must
/// be called only for identified StructTypes and it must be called only once
/// per instance. Otherwise, failure() is returned.
LogicalResult
trySetBody(ArrayRef<Type> memberTypes, ArrayRef<OffsetInfo> offsetInfo = {},
ArrayRef<MemberDecorationInfo> memberDecorations = {});
ArrayRef<MemberDecorationInfo> memberDecorations = {},
ArrayRef<StructDecorationInfo> structDecorations = {});

void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
std::optional<StorageClass> storage = std::nullopt);
Expand All @@ -396,6 +429,9 @@ class StructType
llvm::hash_code
hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo);

llvm::hash_code
hash_value(const StructType::StructDecorationInfo &structDecorationInfo);

// SPIR-V KHR cooperative matrix type
class CooperativeMatrixType
: public Type::TypeBase<CooperativeMatrixType, CompositeType,
Expand Down
59 changes: 54 additions & 5 deletions mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -691,7 +691,9 @@ static ParseResult parseStructMemberDecorations(
// `!spirv.struct<` (id `,`)?
// `(`
// (spirv-type (`[` struct-member-decoration `]`)?)*
// `)>`
// `)`
// (`,` struct-decoration)?
// `>`
static Type parseStructType(SPIRVDialect const &dialect,
DialectAsmParser &parser) {
// TODO: This function is quite lengthy. Break it down into smaller chunks.
Expand Down Expand Up @@ -765,17 +767,48 @@ static Type parseStructType(SPIRVDialect const &dialect,
return Type();
}

if (failed(parser.parseRParen()) || failed(parser.parseGreater()))
if (failed(parser.parseRParen()))
return Type();

SmallVector<StructType::StructDecorationInfo, 1> structDecorationInfo;

auto parseStructDecoration = [&]() {
std::optional<spirv::Decoration> decoration =
parseAndVerify<spirv::Decoration>(dialect, parser);
if (!decoration)
return failure();

// Parse decoration value if it exists.
if (succeeded(parser.parseOptionalEqual())) {
Attribute decorationValue;
if (failed(parser.parseAttribute(decorationValue)))
return failure();

structDecorationInfo.emplace_back(decoration.value(), decorationValue);
} else {
structDecorationInfo.emplace_back(decoration.value(),
UnitAttr::get(dialect.getContext()));
}
return success();
};

while (succeeded(parser.parseOptionalComma()))
if (failed(parseStructDecoration()))
return Type();

if (failed(parser.parseGreater()))
return Type();

if (!identifier.empty()) {
if (failed(idStructTy.trySetBody(memberTypes, offsetInfo,
memberDecorationInfo)))
memberDecorationInfo,
structDecorationInfo)))
return Type();
return idStructTy;
}

return StructType::get(memberTypes, offsetInfo, memberDecorationInfo);
return StructType::get(memberTypes, offsetInfo, memberDecorationInfo,
structDecorationInfo);
}

// spirv-type ::= array-type
Expand Down Expand Up @@ -891,7 +924,23 @@ static void print(StructType type, DialectAsmPrinter &os) {
};
llvm::interleaveComma(llvm::seq<unsigned>(0, type.getNumElements()), os,
printMember);
os << ")>";
os << ")";

SmallVector<spirv::StructType::StructDecorationInfo, 1> decorations;
type.getStructDecorations(decorations);
if (!decorations.empty()) {
os << ", ";
auto eachFn = [&os](spirv::StructType::StructDecorationInfo decoration) {
os << stringifyDecoration(decoration.decoration);
if (decoration.hasValue()) {
os << "=";
os.printAttributeWithoutType(decoration.decorationValue);
}
};
llvm::interleaveComma(decorations, os, eachFn);
}

os << ">";
}

static void print(CooperativeMatrixType type, DialectAsmPrinter &os) {
Expand Down
Loading