Skip to content

Commit f162083

Browse files
committed
test_compile: Add test for compile_multi_file().
1 parent 7c0e177 commit f162083

File tree

1 file changed

+88
-1
lines changed

1 file changed

+88
-1
lines changed

tests/test_compile.py

Lines changed: 88 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,14 @@
22
# Copyright (c) 2022 The Pybricks Authors
33

44

5+
import contextlib
56
import os
67
import struct
78
from tempfile import TemporaryDirectory
89

910
import pytest
1011

11-
from pybricksdev.compile import compile_file
12+
from pybricksdev.compile import compile_file, compile_multi_file
1213

1314

1415
@pytest.mark.parametrize("abi", [5, 6])
@@ -38,3 +39,89 @@ async def test_compile_file_invalid_abi():
3839

3940
with pytest.raises(ValueError, match="mpy_version must be 5 or 6"):
4041
await compile_file(os.path.dirname(f.name), os.path.basename(f.name), abi=4)
42+
43+
44+
@pytest.mark.parametrize("abi", [5, 6])
45+
@pytest.mark.asyncio
46+
async def test_compile_multi_file(abi: int):
47+
with TemporaryDirectory() as temp_dir, contextlib.chdir(temp_dir):
48+
with open(os.path.join(temp_dir, "test.py"), "w", encoding="utf-8") as f:
49+
f.writelines(
50+
[
51+
"from pybricks import version\n",
52+
"import test1\n",
53+
"from test2 import thing2\n",
54+
"from nested.test3 import thing3\n",
55+
]
56+
)
57+
58+
with open(os.path.join(temp_dir, "test1.py"), "w", encoding="utf-8") as f1:
59+
f1.write("thing1 = 'thing1'\n")
60+
61+
with open(os.path.join(temp_dir, "test2.py"), "w", encoding="utf-8") as f2:
62+
f2.write("thing2 = 'thing2'\n")
63+
64+
os.mkdir("nested")
65+
66+
# Work around bug where ModuleFinder can't handle implicit namespace
67+
# packages by adding an __init__.py file.
68+
with open(
69+
os.path.join(temp_dir, "nested", "__init__.py"), "w", encoding="utf-8"
70+
) as init:
71+
init.write("")
72+
73+
with open(
74+
os.path.join(temp_dir, "nested", "test3.py"), "w", encoding="utf-8"
75+
) as f3:
76+
f3.write("thing3 = 'thing3'\n")
77+
78+
multi_mpy = await compile_multi_file("test.py", abi)
79+
pos = 0
80+
81+
def unpack_mpy(data: bytes) -> tuple[bytes, bytes]:
82+
nonlocal pos
83+
84+
size = struct.unpack_from("<I", multi_mpy, pos)[0]
85+
pos += 4
86+
87+
name = bytearray()
88+
89+
# read zero-terminated string
90+
while multi_mpy[pos] != 0:
91+
name.append(multi_mpy[pos])
92+
pos += 1
93+
94+
pos += 1 # skip 0 byte
95+
96+
mpy = multi_mpy[pos : pos + size]
97+
pos += size
98+
99+
return name, mpy
100+
101+
name1, mpy1 = unpack_mpy(multi_mpy)
102+
name2, mpy2 = unpack_mpy(multi_mpy)
103+
name3, mpy3 = unpack_mpy(multi_mpy)
104+
name4, mpy4 = unpack_mpy(multi_mpy)
105+
name5, mpy5 = unpack_mpy(multi_mpy)
106+
107+
assert pos == len(multi_mpy)
108+
109+
assert name1.decode() == "__main__"
110+
assert name2.decode() == "test1"
111+
assert name3.decode() == "test2"
112+
assert name4.decode() == "nested"
113+
assert name5.decode() == "nested.test3"
114+
115+
def check_mpy(mpy: bytes) -> None:
116+
magic, abi_ver, flags, int_bits = struct.unpack_from("<BBBB", mpy)
117+
118+
assert chr(magic) == "M"
119+
assert abi_ver == abi
120+
assert flags == 0
121+
assert int_bits == 31
122+
123+
check_mpy(mpy1)
124+
check_mpy(mpy2)
125+
check_mpy(mpy3)
126+
check_mpy(mpy4)
127+
check_mpy(mpy5)

0 commit comments

Comments
 (0)