Skip to content

Commit 4f295ac

Browse files
committed
[mypyc] Add support for C string literals in the IR
Previously only Python str and bytes literals were supported, but sometimes we want zero-terminated C string literals instead. They don't need to be allocated from heap and are usually stored in a read-only data section, so they are more efficient in some use cases.
1 parent 4980ae5 commit 4f295ac

File tree

5 files changed

+76
-4
lines changed

5 files changed

+76
-4
lines changed

mypyc/codegen/emitfunc.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
Cast,
3434
ComparisonOp,
3535
ControlOp,
36+
CString,
3637
DecRef,
3738
Extend,
3839
Float,
@@ -843,6 +844,8 @@ def reg(self, reg: Value) -> str:
843844
elif r == "nan":
844845
return "NAN"
845846
return r
847+
elif isinstance(reg, CString):
848+
return '"' + encode_c_string_literal(reg.value) + '"'
846849
else:
847850
return self.emitter.reg(reg)
848851

@@ -904,3 +907,26 @@ def emit_unsigned_int_cast(self, type: RType) -> str:
904907
return "(uint64_t)"
905908
else:
906909
return ""
910+
911+
912+
_translation_table: Final[dict[int, str]] = {}
913+
914+
915+
def encode_c_string_literal(b: bytes) -> str:
916+
if not _translation_table:
917+
# Initialize the translation table on the first call.
918+
d = {
919+
ord("\n"): "\\n",
920+
ord("\r"): "\\r",
921+
ord("\t"): "\\t",
922+
ord('"'): '\\"',
923+
ord("\\"): "\\\\",
924+
}
925+
for i in range(256):
926+
if i not in d:
927+
if i < 32 or i >= 127:
928+
d[i] = "\\x%.2x" % i
929+
else:
930+
d[i] = chr(i)
931+
_translation_table.update(str.maketrans(d))
932+
return b.decode("latin1").translate(_translation_table)

mypyc/ir/ops.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ class to enable the new behavior. Sometimes adding a new abstract
3939
RVoid,
4040
bit_rprimitive,
4141
bool_rprimitive,
42+
cstring_rprimitive,
4243
float_rprimitive,
4344
int_rprimitive,
4445
is_bit_rprimitive,
@@ -230,6 +231,20 @@ def __init__(self, value: float, line: int = -1) -> None:
230231
self.line = line
231232

232233

234+
@final
235+
class CString(Value):
236+
"""C string literal (zero-terminated).
237+
238+
You can also include zero values in the value, but then you'll need to track
239+
the length of the string separately.
240+
"""
241+
242+
def __init__(self, value: bytes, line: int = -1) -> None:
243+
self.value = value
244+
self.type = cstring_rprimitive
245+
self.line = line
246+
247+
233248
class Op(Value):
234249
"""Abstract base class for all IR operations.
235250

mypyc/ir/pprint.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
Cast,
2222
ComparisonOp,
2323
ControlOp,
24+
CString,
2425
DecRef,
2526
Extend,
2627
Float,
@@ -327,6 +328,8 @@ def format(self, fmt: str, *args: Any) -> str:
327328
result.append(str(arg.value))
328329
elif isinstance(arg, Float):
329330
result.append(repr(arg.value))
331+
elif isinstance(arg, CString):
332+
result.append(f"CString({arg.value!r})")
330333
else:
331334
result.append(self.names[arg])
332335
elif typespec == "d":

mypyc/ir/rtypes.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -254,13 +254,11 @@ def __init__(
254254
elif ctype == "CPyPtr":
255255
# TODO: Invent an overlapping error value?
256256
self.c_undefined = "0"
257-
elif ctype == "PyObject *":
258-
# Boxed types use the null pointer as the error value.
257+
elif ctype.endswith("*"):
258+
# Boxed and pointer types use the null pointer as the error value.
259259
self.c_undefined = "NULL"
260260
elif ctype == "char":
261261
self.c_undefined = "2"
262-
elif ctype in ("PyObject **", "void *"):
263-
self.c_undefined = "NULL"
264262
elif ctype == "double":
265263
self.c_undefined = "-113.0"
266264
elif ctype in ("uint8_t", "uint16_t", "uint32_t", "uint64_t"):
@@ -445,6 +443,10 @@ def __hash__(self) -> int:
445443
"c_ptr", is_unboxed=False, is_refcounted=False, ctype="void *"
446444
)
447445

446+
cstring_rprimitive: Final = RPrimitive(
447+
"cstring", is_unboxed=True, is_refcounted=False, ctype="const char *"
448+
)
449+
448450
# The type corresponding to mypyc.common.BITMAP_TYPE
449451
bitmap_rprimitive: Final = uint32_rprimitive
450452

mypyc/test/test_emitfunc.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
CallC,
2020
Cast,
2121
ComparisonOp,
22+
CString,
2223
DecRef,
2324
Extend,
2425
GetAttr,
@@ -49,6 +50,7 @@
4950
RType,
5051
bool_rprimitive,
5152
c_int_rprimitive,
53+
cstring_rprimitive,
5254
dict_rprimitive,
5355
int32_rprimitive,
5456
int64_rprimitive,
@@ -824,6 +826,30 @@ def test_inc_ref_int_literal(self) -> None:
824826
b = LoadLiteral(x, object_rprimitive)
825827
self.assert_emit([b, IncRef(b)], "CPy_INCREF(cpy_r_r0);")
826828

829+
def test_c_string(self) -> None:
830+
s = Register(cstring_rprimitive, "s")
831+
self.assert_emit(Assign(s, CString(b"foo")), """cpy_r_s = "foo";""")
832+
self.assert_emit(Assign(s, CString(b'fo "o')), r"""cpy_r_s = "fo \"o";""")
833+
self.assert_emit(Assign(s, CString(b"\x00")), r"""cpy_r_s = "\x00";""")
834+
self.assert_emit(Assign(s, CString(b"\\")), r"""cpy_r_s = "\\";""")
835+
for i in range(256):
836+
b = bytes([i])
837+
if b == b"\n":
838+
target = "\\n"
839+
elif b == b"\r":
840+
target = "\\r"
841+
elif b == b"\t":
842+
target = "\\t"
843+
elif b == b'"':
844+
target = '\\"'
845+
elif b == b"\\":
846+
target = "\\\\"
847+
elif i < 32 or i >= 127:
848+
target = "\\x%.2x" % i
849+
else:
850+
target = b.decode("ascii")
851+
self.assert_emit(Assign(s, CString(b)), f'cpy_r_s = "{target}";')
852+
827853
def assert_emit(
828854
self,
829855
op: Op | list[Op],

0 commit comments

Comments
 (0)