diff --git a/mlir/test/mlir-tblgen/enums-python-bindings.td b/mlir/test/mlir-tblgen/enums-python-bindings.td index 1c5567f54a5f4..cd23b6a2effb9 100644 --- a/mlir/test/mlir-tblgen/enums-python-bindings.td +++ b/mlir/test/mlir-tblgen/enums-python-bindings.td @@ -62,12 +62,15 @@ def MyEnum64 : I64EnumAttr<"MyEnum64", "An example 64-bit enum", [One64, Two64]> // CHECK: def _myenum64(x, context): // CHECK: return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(64, context=context), int(x)) +def User : I32BitEnumAttrCaseBit<"User", 0, "user">; +def Group : I32BitEnumAttrCaseBit<"Group", 1, "group">; +def Other : I32BitEnumAttrCaseBit<"Other", 2, "other">; + def TestBitEnum - : I32BitEnumAttr<"TestBitEnum", "", [ - I32BitEnumAttrCaseBit<"User", 0, "user">, - I32BitEnumAttrCaseBit<"Group", 1, "group">, - I32BitEnumAttrCaseBit<"Other", 2, "other">, - ]> { + : I32BitEnumAttr< + "TestBitEnum", "", + [User, Group, Other, + I32BitEnumAttrCaseGroup<"Any", [User, Group, Other], "any">]> { let genSpecializedAttr = 0; let separator = " | "; } @@ -79,9 +82,10 @@ def TestBitEnum_Attr : EnumAttr; // CHECK: User = 1 // CHECK: Group = 2 // CHECK: Other = 4 +// CHECK: Any = 7 // CHECK: def __iter__(self): -// CHECK: return iter([case for case in type(self) if (self & case) is case]) +// CHECK: return iter([case for case in type(self) if (self & case) is case and self is not case]) // CHECK: def __len__(self): // CHECK: return bin(self).count("1") @@ -94,6 +98,8 @@ def TestBitEnum_Attr : EnumAttr; // CHECK: return "group" // CHECK: if self is TestBitEnum.Other: // CHECK: return "other" +// CHECK: if self is TestBitEnum.Any: +// CHECK: return "any" // CHECK: raise ValueError("Unknown TestBitEnum enum entry.") // CHECK: @register_attribute_builder("TestBitEnum") diff --git a/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp index 8e2d6114e48eb..acc9b61d7121c 100644 --- a/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp +++ b/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp @@ -64,7 +64,7 @@ static void emitEnumClass(EnumInfo enumInfo, raw_ostream &os) { if (enumInfo.isBitEnum()) { os << formatv(" def __iter__(self):\n" " return iter([case for case in type(self) if " - "(self & case) is case])\n"); + "(self & case) is case and self is not case])\n"); os << formatv(" def __len__(self):\n" " return bin(self).count(\"1\")\n"); os << "\n";