Skip to content

Commit 9df8bc9

Browse files
authored
[mlir][spirv][nfc] Refactor member decorations in StructType (#150218)
This patch makes `==` and `<` for MemberDecorationInfo a friend function and removes a `hasValue` field. `decorationValue` is also made an `mlir::Attribute` so `UnitAttr` can be used to represent no-value. This is consistent with how OpDecorate is handled in the deserializer. Using `Attribute` will also enable handling non-integer values, however, there seem to be no such decorations for struct members now.
1 parent 8fcbd06 commit 9df8bc9

File tree

5 files changed

+42
-36
lines changed

5 files changed

+42
-36
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -301,30 +301,33 @@ class StructType
301301

302302
static constexpr StringLiteral name = "spirv.struct";
303303

304-
// Type for specifying the decoration(s) on struct members
304+
// Type for specifying the decoration(s) on struct members.
305+
// If `decorationValue` is UnitAttr then decoration has no
306+
// value.
305307
struct MemberDecorationInfo {
306-
uint32_t memberIndex : 31;
307-
uint32_t hasValue : 1;
308+
uint32_t memberIndex;
308309
Decoration decoration;
309-
uint32_t decorationValue;
310+
Attribute decorationValue;
310311

311-
MemberDecorationInfo(uint32_t index, uint32_t hasValue,
312-
Decoration decoration, uint32_t decorationValue)
313-
: memberIndex(index), hasValue(hasValue), decoration(decoration),
312+
MemberDecorationInfo(uint32_t index, Decoration decoration,
313+
Attribute decorationValue)
314+
: memberIndex(index), decoration(decoration),
314315
decorationValue(decorationValue) {}
315316

316-
bool operator==(const MemberDecorationInfo &other) const {
317-
return (this->memberIndex == other.memberIndex) &&
318-
(this->decoration == other.decoration) &&
319-
(this->decorationValue == other.decorationValue);
317+
friend bool operator==(const MemberDecorationInfo &lhs,
318+
const MemberDecorationInfo &rhs) {
319+
return lhs.memberIndex == rhs.memberIndex &&
320+
lhs.decoration == rhs.decoration &&
321+
lhs.decorationValue == rhs.decorationValue;
320322
}
321323

322-
bool operator<(const MemberDecorationInfo &other) const {
323-
return this->memberIndex < other.memberIndex ||
324-
(this->memberIndex == other.memberIndex &&
325-
static_cast<uint32_t>(this->decoration) <
326-
static_cast<uint32_t>(other.decoration));
324+
friend bool operator<(const MemberDecorationInfo &lhs,
325+
const MemberDecorationInfo &rhs) {
326+
return std::tuple(lhs.memberIndex, llvm::to_underlying(lhs.decoration)) <
327+
std::tuple(rhs.memberIndex, llvm::to_underlying(rhs.decoration));
327328
}
329+
330+
bool hasValue() const { return !isa<UnitAttr>(decorationValue); }
328331
};
329332

330333
/// Construct a literal StructType with at least one member.

mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -665,19 +665,17 @@ static ParseResult parseStructMemberDecorations(
665665

666666
// Parse member decoration value if it exists.
667667
if (succeeded(parser.parseOptionalEqual())) {
668-
auto memberDecorationValue =
669-
parseAndVerifyInteger<uint32_t>(dialect, parser);
670-
671-
if (!memberDecorationValue)
668+
Attribute memberDecorationValue;
669+
if (failed(parser.parseAttribute(memberDecorationValue)))
672670
return failure();
673671

674672
memberDecorationInfo.emplace_back(
675-
static_cast<uint32_t>(memberTypes.size() - 1), 1,
676-
memberDecoration.value(), memberDecorationValue.value());
673+
static_cast<uint32_t>(memberTypes.size() - 1),
674+
memberDecoration.value(), memberDecorationValue);
677675
} else {
678676
memberDecorationInfo.emplace_back(
679-
static_cast<uint32_t>(memberTypes.size() - 1), 0,
680-
memberDecoration.value(), 0);
677+
static_cast<uint32_t>(memberTypes.size() - 1),
678+
memberDecoration.value(), UnitAttr::get(dialect.getContext()));
681679
}
682680
return success();
683681
};
@@ -882,8 +880,9 @@ static void print(StructType type, DialectAsmPrinter &os) {
882880
}
883881
auto eachFn = [&os](spirv::StructType::MemberDecorationInfo decoration) {
884882
os << stringifyDecoration(decoration.decoration);
885-
if (decoration.hasValue) {
886-
os << "=" << decoration.decorationValue;
883+
if (decoration.hasValue()) {
884+
os << "=";
885+
os.printAttributeWithoutType(decoration.decorationValue);
887886
}
888887
};
889888
llvm::interleaveComma(decorations, os, eachFn);

mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1188,13 +1188,14 @@ spirv::Deserializer::processStructType(ArrayRef<uint32_t> operands) {
11881188
}
11891189
offsetInfo[memberIndex] = memberDecoration.second[0];
11901190
} else {
1191+
auto intType = mlir::IntegerType::get(context, 32);
11911192
if (!memberDecoration.second.empty()) {
1192-
memberDecorationsInfo.emplace_back(memberIndex, /*hasValue=*/1,
1193-
memberDecoration.first,
1194-
memberDecoration.second[0]);
1193+
memberDecorationsInfo.emplace_back(
1194+
memberIndex, memberDecoration.first,
1195+
IntegerAttr::get(intType, memberDecoration.second[0]));
11951196
} else {
1196-
memberDecorationsInfo.emplace_back(memberIndex, /*hasValue=*/0,
1197-
memberDecoration.first, 0);
1197+
memberDecorationsInfo.emplace_back(
1198+
memberIndex, memberDecoration.first, UnitAttr::get(context));
11981199
}
11991200
}
12001201
}

mlir/lib/Target/SPIRV/Serialization/Serializer.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -406,8 +406,9 @@ LogicalResult Serializer::processMemberDecoration(
406406
SmallVector<uint32_t, 4> args(
407407
{structID, memberDecoration.memberIndex,
408408
static_cast<uint32_t>(memberDecoration.decoration)});
409-
if (memberDecoration.hasValue) {
410-
args.push_back(memberDecoration.decorationValue);
409+
if (memberDecoration.hasValue()) {
410+
args.push_back(
411+
cast<IntegerAttr>(memberDecoration.decorationValue).getInt());
411412
}
412413
encodeInstructionInto(decorations, spirv::Opcode::OpMemberDecorate, args);
413414
return success();
@@ -666,10 +667,12 @@ LogicalResult Serializer::prepareBasicType(
666667
}
667668
operands.push_back(elementTypeID);
668669
if (hasOffset) {
670+
auto intType = IntegerType::get(structType.getContext(), 32);
669671
// Decorate each struct member with an offset
670672
spirv::StructType::MemberDecorationInfo offsetDecoration{
671-
elementIndex, /*hasValue=*/1, spirv::Decoration::Offset,
672-
static_cast<uint32_t>(structType.getMemberOffset(elementIndex))};
673+
elementIndex, spirv::Decoration::Offset,
674+
IntegerAttr::get(intType,
675+
structType.getMemberOffset(elementIndex))};
673676
if (failed(processMemberDecoration(resultID, offsetDecoration))) {
674677
return emitError(loc, "cannot decorate ")
675678
<< elementIndex << "-th member of " << structType

mlir/test/Dialect/SPIRV/IR/types.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,7 @@ func.func private @struct_type_missing_comma(!spirv.struct<(!spirv.matrix<3 x ve
353353

354354
// -----
355355

356-
// expected-error @+1 {{expected integer value}}
356+
// expected-error @+1 {{expected attribute value}}
357357
func.func private @struct_missing_member_decorator_value(!spirv.struct<(!spirv.matrix<3 x vector<3xf32>> [0, RowMajor, MatrixStride=])>)
358358

359359
// -----

0 commit comments

Comments
 (0)