MLIR Enum Python bindings infinite recursion #151584

@nsmithtt

Description

Using an I32BitEnumAttrCaseGroup case in a bit enum causes infinite recursion in the generated Python bindings.

Example repro

Example TableGen:

def TTCore_ChipCapabilityPCIE : I32BitEnumAttrCaseBit<"PCIE", 0, "pcie">;
def TTCore_ChipCapabilityHostMMIO : I32BitEnumAttrCaseBit<"HostMMIO", 1, "host_mmio">;
def TTCore_ChipCapabilityAll : I32BitEnumAttrCaseGroup<"All",
    [TTCore_ChipCapabilityPCIE, TTCore_ChipCapabilityHostMMIO], "all">;

def TTCore_ChipCapability : I32BitEnumAttr<"ChipCapability", "TT Chip Capabilities",
                           [
                            TTCore_ChipCapabilityPCIE,
                            TTCore_ChipCapabilityHostMMIO,
                            TTCore_ChipCapabilityAll,
                           ]> {
  let genSpecializedAttr = 1;
  let cppNamespace = "::mlir::tt::ttcore";
}
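The values in the generated class below follow directly from these definitions. Here is a minimal sketch of that arithmetic. It assumes, as the generated PCIE = 1, HostMMIO = 2, All = 3 suggest, that the second argument of I32BitEnumAttrCaseBit is a bit position and that a case group's value is the bitwise OR of its member cases. The helper names are hypothetical and not part of MLIR:

```python
from functools import reduce
from operator import or_


def case_bit_value(bit: int) -> int:
    """Hypothetical helper: value of an I32BitEnumAttrCaseBit, a single bit at `bit`."""
    return 1 << bit


def case_group_value(*case_values: int) -> int:
    """Hypothetical helper: value of an I32BitEnumAttrCaseGroup, the OR of its cases."""
    return reduce(or_, case_values, 0)


PCIE = case_bit_value(0)                 # bit 0 -> 1
HOST_MMIO = case_bit_value(1)            # bit 1 -> 2
ALL = case_group_value(PCIE, HOST_MMIO)  # 1 | 2 -> 3

# Number of set bits in ALL, computed the same way as the generated __len__.
print(PCIE, HOST_MMIO, ALL, bin(ALL).count("1"))  # 1 2 3 2
```

The last value printed is what the generated `__len__` returns for All. Because it is greater than 1, `__str__` treats All as a combination of cases rather than a single case.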

This generates the following Python binding:

class ChipCapability(IntFlag):
    """TT Chip Capabilities"""

    PCIE = 1
    HostMMIO = 2
    All = 3

    def __iter__(self):
        return iter([case for case in type(self) if (self & case) is case])
    def __len__(self):
        return bin(self).count("1")

    def __str__(self):
        if len(self) > 1:
            return "|".join(map(str, self))
        if self is ChipCapability.PCIE:
            return "pcie"
        if self is ChipCapability.HostMMIO:
            return "host_mmio"
        if self is ChipCapability.All:
            return "all"
        raise ValueError("Unknown ChipCapability enum entry.")

Calling str(ChipCapability.All) results in infinite recursion through the following sequence:

  1. Call __str__ on ChipCapability.All. Since len(All) is 2, we take the first branch and map str over the cases yielded by iterating self.
  2. This goes into __iter__, which loops over the members of the enum class. It's especially useful to print(list(case for case in type(self))) right here; we can see:
     [<ChipCapability.PCIE: 1>, <ChipCapability.HostMMIO: 2>, <ChipCapability.All: 3>]
  3. Because <ChipCapability.All: 3> is in that list and (All & All) is All, __iter__ returns All as one of its own cases. The map then calls str(All) again, which repeats step 1 without end.
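The steps above can be reproduced with a self-contained script. It uses the generated class verbatim and adds only the IntFlag import. Whether the script actually recurses depends on whether iterating the class yields the multi-bit All member, as in the printed list. As far as I know, Python 3.11 and later treat multi-bit Flag values such as All as aliases and leave them out of class iteration. On those versions, the list would lack All and the recursion would not occur. For that reason, the script checks this at runtime instead of assuming it. The helper `str_recurses` is illustrative only.

```python
from enum import IntFlag


class ChipCapability(IntFlag):
    """TT Chip Capabilities"""

    PCIE = 1
    HostMMIO = 2
    All = 3

    def __iter__(self):
        return iter([case for case in type(self) if (self & case) is case])
    def __len__(self):
        return bin(self).count("1")

    def __str__(self):
        if len(self) > 1:
            return "|".join(map(str, self))
        if self is ChipCapability.PCIE:
            return "pcie"
        if self is ChipCapability.HostMMIO:
            return "host_mmio"
        if self is ChipCapability.All:
            return "all"
        raise ValueError("Unknown ChipCapability enum entry.")


# Step 2: does iterating the class (what __iter__ does) include the group member All?
all_is_iterated = any(m is ChipCapability.All for m in ChipCapability)


def str_recurses(value) -> bool:
    """Illustrative helper: True if str(value) hits Python's recursion limit."""
    try:
        str(value)
    except RecursionError:
        return True
    return False


# Step 3: str(All) recurses exactly when All is yielded by its own __iter__.
recursed = str_recurses(ChipCapability.All)
print(f"All in class iteration: {all_is_iterated}; str(All) recursed: {recursed}")
```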

Proposed Fix

The proposed fix is to filter self out of the iteration, keeping only cases that are not self:

    def __iter__(self):
        return iter([case for case in type(self) if (self & case) is case and self is not case])
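Below is a sketch of the binding with the proposed `__iter__` applied. The rest of the generated class is unchanged, and the import is added so the code runs standalone. With All excluded from its own iteration, `str(ChipCapability.All)` expands to its component cases instead of recursing. The result is the same whether or not the Python version includes All in class iteration.

```python
from enum import IntFlag


class ChipCapability(IntFlag):
    """TT Chip Capabilities"""

    PCIE = 1
    HostMMIO = 2
    All = 3

    def __iter__(self):
        # Proposed fix: never yield self, so a group case such as All
        # cannot end up calling str() on itself.
        return iter([case for case in type(self) if (self & case) is case and self is not case])
    def __len__(self):
        # Number of set bits, i.e. how many single-bit cases are combined.
        return bin(self).count("1")

    def __str__(self):
        if len(self) > 1:
            return "|".join(map(str, self))
        if self is ChipCapability.PCIE:
            return "pcie"
        if self is ChipCapability.HostMMIO:
            return "host_mmio"
        if self is ChipCapability.All:
            return "all"
        raise ValueError("Unknown ChipCapability enum entry.")


print(str(ChipCapability.All))  # pcie|host_mmio
```

With this fix, the `"all"` branch of `__str__` is still never reached for All. Because `len(All)` is 2, the join branch always handles it first.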
