Skip to content

Commit 3709b0e

Browse files
use new safe key
1 parent 65aefba commit 3709b0e

File tree

1 file changed

+15
-12
lines changed

1 file changed

+15
-12
lines changed

mypyc/codegen/emit.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -204,19 +204,20 @@ def object_annotation(self, obj: object, line: str) -> str:
204204
205205
If it contains illegal characters, an empty string is returned."""
206206
line_width = self._indent + len(line)
207+
208+
# temporarily override pprint._safe_key
209+
default_safe_key = pprint._safe_key
210+
pprint._safe_key = _mypyc_safe_key
211+
212+
# pretty print the object
207213
formatted = pprint.pformat(obj, compact=True, width=max(90 - line_width, 20))
214+
215+
# replace the _safe_key
216+
pprint._safe_key = default_safe_key
217+
208218
if any(x in formatted for x in ("/*", "*/", "\0")):
209219
return ""
210220

211-
# make frozenset annotations deterministic
212-
if formatted.startswith("frozenset({"):
213-
frozenset_items = formatted[11:-2]
214-
# if our frozenset contains another frozenset or a tuple, we will need better logic
215-
# here, but this rudimentary logic will still vastly improve codegen determinism.
216-
if "(" not in frozenset_items:
217-
sorted_items = ", ".join(sorted(frozenset_items.split(", ")))
218-
formatted = "frozenset({" + sorted_items + "})"
219-
220221
if "\n" in formatted:
221222
first_line, rest = formatted.split("\n", maxsplit=1)
222223
comment_continued = textwrap.indent(rest, (line_width + 3) * " ")
@@ -1239,8 +1240,10 @@ def c_array_initializer(components: list[str], *, indented: bool = False) -> str
12391240

12401241
class _mypyc_safe_key(pprint._safe_key):
12411242
"""A custom sort key implementation for pprint that makes the output deterministic
1242-
for all literal types supported by mypyc
1243-
"""
1243+
for all literal types supported by mypyc.
12441244
1245-
def __lt__(self, other: _mypyc_safe_key) -> bool:
1245+
This is NOT safe for use as a sort key for other types, so we MUST replace the
1246+
original pprint._safe_key once we've pprinted our object.
1247+
"""
1248+
def __lt__(self, other: "_mypyc_safe_key") -> bool:
12461249
return str(type(self.obj)) + repr(self.obj) < str(type(other.obj)) + repr(other.obj)

0 commit comments

Comments
 (0)