Skip to content

Commit acb826e

Browse files
authored
[MLIR][TBLGen] Added compound assignment operator for any BitEnum (#160840)
## Details: - Added missing compound assignment operators `|=`, `&=`, `^=` to `mlir-tblgen` - Replaced the arithmetic operators with added assignment operators for `BitEnum` in the transformations - Updated related documentation ## Tickets: - Closes #158098
1 parent e2d5efd commit acb826e

File tree

5 files changed

+25
-4
lines changed

5 files changed

+25
-4
lines changed

mlir/docs/DefiningDialects/Operations.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1649,6 +1649,15 @@ inline constexpr MyBitEnum operator&(MyBitEnum a, MyBitEnum b) {
16491649
inline constexpr MyBitEnum operator^(MyBitEnum a, MyBitEnum b) {
16501650
return static_cast<MyBitEnum>(static_cast<uint32_t>(a) ^ static_cast<uint32_t>(b));
16511651
}
1652+
inline constexpr MyBitEnum &operator|=(MyBitEnum &a, MyBitEnum b) {
1653+
return a = a | b;
1654+
}
1655+
inline constexpr MyBitEnum &operator&=(MyBitEnum &a, MyBitEnum b) {
1656+
return a = a & b;
1657+
}
1658+
inline constexpr MyBitEnum &operator^=(MyBitEnum &a, MyBitEnum b) {
1659+
return a = a ^ b;
1660+
}
16521661
inline constexpr MyBitEnum operator~(MyBitEnum bits) {
16531662
// Ensure only bits that can be present in the enum are set
16541663
return static_cast<MyBitEnum>(~static_cast<uint32_t>(bits) & static_cast<uint32_t>(15u));

mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -512,7 +512,7 @@ calculateMemoryRequirements(Value accessedPtr, bool isNontemporal,
512512
if (!sizeInBytes.has_value())
513513
return failure();
514514

515-
memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned;
515+
memoryAccess |= spirv::MemoryAccess::Aligned;
516516
auto memAccessAttr = spirv::MemoryAccessAttr::get(ctx, memoryAccess);
517517
auto alignmentValue = preferredAlignment ? preferredAlignment : *sizeInBytes;
518518
auto alignment = IntegerAttr::get(IntegerType::get(ctx, 32), alignmentValue);

mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -753,7 +753,7 @@ struct VectorLoadOpConverter final
753753
spirv::MemoryAccessAttr memoryAccessAttr;
754754
IntegerAttr alignmentAttr;
755755
if (alignment.has_value()) {
756-
memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned;
756+
memoryAccess |= spirv::MemoryAccess::Aligned;
757757
memoryAccessAttr =
758758
spirv::MemoryAccessAttr::get(rewriter.getContext(), memoryAccess);
759759
alignmentAttr = rewriter.getI32IntegerAttr(alignment.value());
@@ -822,7 +822,7 @@ struct VectorStoreOpConverter final
822822
spirv::MemoryAccessAttr memoryAccessAttr;
823823
IntegerAttr alignmentAttr;
824824
if (alignment.has_value()) {
825-
memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned;
825+
memoryAccess |= spirv::MemoryAccess::Aligned;
826826
memoryAccessAttr =
827827
spirv::MemoryAccessAttr::get(rewriter.getContext(), memoryAccess);
828828
alignmentAttr = rewriter.getI32IntegerAttr(alignment.value());

mlir/lib/Dialect/LLVMIR/Transforms/DIScopeForLLVMFuncOp.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ static void addScopeToFunction(LLVM::LLVMFuncOp llvmFunc,
7777
auto subprogramFlags = LLVM::DISubprogramFlags::Optimized;
7878
if (!llvmFunc.isExternal()) {
7979
id = DistinctAttr::create(UnitAttr::get(context));
80-
subprogramFlags = subprogramFlags | LLVM::DISubprogramFlags::Definition;
80+
subprogramFlags |= LLVM::DISubprogramFlags::Definition;
8181
} else {
8282
compileUnitAttr = {};
8383
}

mlir/tools/mlir-tblgen/EnumsGen.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,9 @@ getAllBitsUnsetCase(llvm::ArrayRef<EnumCase> cases) {
364364
// inline constexpr <enum-type> operator|(<enum-type> a, <enum-type> b);
365365
// inline constexpr <enum-type> operator&(<enum-type> a, <enum-type> b);
366366
// inline constexpr <enum-type> operator^(<enum-type> a, <enum-type> b);
367+
// inline constexpr <enum-type> &operator|=(<enum-type> &a, <enum-type> b);
368+
// inline constexpr <enum-type> &operator&=(<enum-type> &a, <enum-type> b);
369+
// inline constexpr <enum-type> &operator^=(<enum-type> &a, <enum-type> b);
367370
// inline constexpr <enum-type> operator~(<enum-type> bits);
368371
// inline constexpr bool bitEnumContainsAll(<enum-type> bits, <enum-type> bit);
369372
// inline constexpr bool bitEnumContainsAny(<enum-type> bits, <enum-type> bit);
@@ -385,6 +388,15 @@ inline constexpr {0} operator&({0} a, {0} b) {{
385388
inline constexpr {0} operator^({0} a, {0} b) {{
386389
return static_cast<{0}>(static_cast<{1}>(a) ^ static_cast<{1}>(b));
387390
}
391+
inline constexpr {0} &operator|=({0} &a, {0} b) {{
392+
return a = a | b;
393+
}
394+
inline constexpr {0} &operator&=({0} &a, {0} b) {{
395+
return a = a & b;
396+
}
397+
inline constexpr {0} &operator^=({0} &a, {0} b) {{
398+
return a = a ^ b;
399+
}
388400
inline constexpr {0} operator~({0} bits) {{
389401
// Ensure only bits that can be present in the enum are set
390402
return static_cast<{0}>(~static_cast<{1}>(bits) & static_cast<{1}>({2}u));

0 commit comments

Comments
 (0)