Skip to content

Commit 195726f

Browse files
authored
Merge pull request #13 from muna-ai/transpile
`muna transpile`
2 parents cc3bd6a + 7ba1a22 commit 195726f

File tree

4 files changed

+194
-57
lines changed

4 files changed

+194
-57
lines changed

Changelog.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
## 0.0.83
2-
*INCOMPLETE*
2+
+ Added `muna transpile` CLI command to transpile Python functions to C++.
33

44
## 0.0.82
55
+ Minor fixes.

muna/cli/__init__.py

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,13 @@
99
from ..version import __version__
1010

1111
from .auth import app as auth_app
12-
from .compile import compile_predictor, triage_predictor
12+
from .compile import compile_function, transpile_function
1313
from .misc import cli_options
1414
from .predictions import create_prediction
1515
from .predictors import archive_predictor, delete_predictor, retrieve_predictor
1616
from .resources import app as resources_app
1717
from .sources import retrieve_source
18+
from .triage import triage_predictor
1819

1920
# Define CLI
2021
typer.main.console_stderr = TracebackMarkupConsole()
@@ -29,39 +30,45 @@
2930
# Add top level options
3031
app.callback()(cli_options)
3132

32-
# Predictions
33+
# Compilation
34+
app.command(
35+
name="transpile",
36+
help="Transpile a Python function to a self-contained C++ library.",
37+
rich_help_panel="Compilation"
38+
)(transpile_function)
39+
app.command(
40+
name="compile",
41+
help="Compile a Python function for deployment.",
42+
rich_help_panel="Compilation"
43+
)(compile_function)
3344
app.command(
3445
name="predict",
35-
help="Make a prediction.",
46+
help="Invoke a compiled Python function.",
3647
context_settings={ "allow_extra_args": True, "ignore_unknown_options": True },
37-
rich_help_panel="Predictions"
48+
rich_help_panel="Compilation"
3849
)(create_prediction)
3950
app.command(
4051
name="source",
4152
help="Retrieve the generated C++ code for a given prediction.",
42-
rich_help_panel="Predictions"
53+
rich_help_panel="Compilation",
54+
hidden=True
4355
)(retrieve_source)
4456

4557
# Predictors
46-
app.command(
47-
name="compile",
48-
help="Create a predictor by compiling a Python function.",
49-
rich_help_panel="Predictors"
50-
)(compile_predictor)
5158
app.command(
5259
name="retrieve",
53-
help="Retrieve a predictor.",
54-
rich_help_panel="Predictors"
60+
help="Retrieve a compiled function.",
61+
rich_help_panel="Functions"
5562
)(retrieve_predictor)
5663
app.command(
5764
name="archive",
58-
help="Archive a predictor." ,
59-
rich_help_panel="Predictors"
65+
help="Archive a compiled function." ,
66+
rich_help_panel="Functions"
6067
)(archive_predictor)
6168
app.command(
6269
name="delete",
63-
help="Delete a predictor.",
64-
rich_help_panel="Predictors"
70+
help="Delete a compiled function.",
71+
rich_help_panel="Functions"
6572
)(delete_predictor)
6673

6774
# Subcommands

muna/cli/compile.py

Lines changed: 124 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -6,25 +6,35 @@
66
from importlib.util import module_from_spec, spec_from_file_location
77
from inspect import getmembers, getmodulename, isfunction
88
from pathlib import Path
9+
import platform
910
from pydantic import BaseModel
1011
from rich import print as print_rich
11-
from rich.panel import Panel
1212
import sys
1313
from typer import Argument, Option
14-
from typing import Callable, Literal
15-
from typing_extensions import Annotated
14+
from typing import Annotated, Callable, Literal
1615
from urllib.parse import urlparse, urlunparse
1716

1817
from ..client import MunaAPIError
1918
from ..compile import PredictorSpec
19+
from ..logging import CustomProgress, CustomProgressTask
2020
from ..muna import Muna
2121
from ..sandbox import EntrypointCommand
22-
from ..logging import CustomProgress, CustomProgressTask
22+
from ..types import PredictionResource
2323
from .auth import get_access_key
2424

25-
def compile_predictor(
26-
path: str=Argument(..., help="Predictor path."),
27-
overwrite: bool=Option(False, "--overwrite", help="Whether to delete any existing predictor with the same tag before compiling."),
25+
def compile_function(
26+
path: Annotated[str, Argument(
27+
resolve_path=True,
28+
exists=True,
29+
readable=True,
30+
file_okay=True,
31+
dir_okay=False,
32+
help="Python source path."
33+
)],
34+
overwrite: Annotated[bool, Option(
35+
"--overwrite",
36+
help="Whether to delete any existing predictor with the same tag before compiling.")
37+
]=False,
2838
):
2939
muna = Muna(get_access_key())
3040
path: Path = Path(path).resolve()
@@ -49,7 +59,7 @@ def compile_predictor(
4959
f"have a docstring."
5060
)
5161
spec.description = func.__doc__.strip()
52-
task.finish(f"Loaded prediction function: [bold cyan]{spec.tag}[/bold cyan]")
62+
task.finish(f"Loaded Python function: [bold cyan]{spec.tag}[/bold cyan]")
5363
# Populate
5464
sandbox = spec.sandbox
5565
sandbox.commands.append(entrypoint)
@@ -83,39 +93,85 @@ def compile_predictor(
8393
),
8494
response_type=_LogEvent | _ErrorEvent
8595
):
86-
if isinstance(event, _LogEvent):
87-
task_queue.push_log(event)
88-
elif isinstance(event, _ErrorEvent):
89-
task_queue.push_error(event)
90-
raise CompileError(event.data.error)
96+
match event.event:
97+
case "log":
98+
task_queue.push_log(event)
99+
case "error":
100+
task_queue.push_error(event)
101+
raise CompileError(event.data.error)
91102
predictor_url = _compute_predictor_url(muna.client.api_url, spec.tag)
92103
print_rich(f"\n[bold spring_green3]🎉 Predictor is now being compiled.[/bold spring_green3] Check it out at [link={predictor_url}]{predictor_url}[/link]")
93104

94-
def triage_predictor(
95-
reference_code: Annotated[str, Argument(help="Predictor compilation reference code.")]
105+
def transpile_function(
106+
path: Annotated[Path, Argument(
107+
resolve_path=True,
108+
exists=True,
109+
readable=True,
110+
file_okay=True,
111+
dir_okay=False,
112+
help="Python source path."
113+
)],
114+
output: Annotated[Path, Argument(
115+
resolve_path=True,
116+
exists=False,
117+
writable=True,
118+
help="Output path for generated C++ sources."
119+
)]=Path("cpp")
96120
):
97121
muna = Muna(get_access_key())
98-
error = muna.client.request(
99-
method="GET",
100-
path=f"/predictors/triage?referenceCode={reference_code}",
101-
response_type=_TriagedCompileError
102-
)
103-
user_panel = Panel(
104-
error.user,
105-
title="User Error",
106-
title_align="left",
107-
highlight=True,
108-
border_style="bright_red"
109-
)
110-
internal_panel = Panel(
111-
error.internal,
112-
title="Internal Error",
113-
title_align="left",
114-
highlight=True,
115-
border_style="gold1"
116-
)
117-
print_rich(user_panel)
118-
print_rich(internal_panel)
122+
with CustomProgress():
123+
# Load
124+
with CustomProgressTask(loading_text="Loading predictor...") as task:
125+
func = _load_predictor_func(path)
126+
entrypoint = EntrypointCommand(
127+
from_path=str(path),
128+
to_path=f"./{path.name}",
129+
name=func.__name__
130+
)
131+
spec: PredictorSpec = func.__predictor_spec
132+
spec.targets = (
133+
spec.targets
134+
if spec.targets is not None
135+
else [_get_current_target()]
136+
)
137+
task.finish(f"Loaded Python function: [bold cyan]{func.__module__}.{func.__name__}[/bold cyan]")
138+
# Populate
139+
sandbox = spec.sandbox
140+
sandbox.commands.append(entrypoint)
141+
with CustomProgressTask(loading_text="Uploading sandbox...", done_text="Uploaded sandbox"):
142+
sandbox.populate(muna=muna)
143+
# Compile
144+
with CustomProgressTask(loading_text="Running codegen...", done_text="Completed codegen"):
145+
with ProgressLogQueue() as task_queue:
146+
for event in muna.client.stream(
147+
method="POST",
148+
path=f"/transpile",
149+
body=spec.model_dump(
150+
mode="json",
151+
exclude=spec.model_extra.keys(),
152+
by_alias=True
153+
),
154+
response_type=_LogEvent | _ErrorEvent | _SourceEvent
155+
):
156+
match event.event:
157+
case "log":
158+
task_queue.push_log(event)
159+
case "error":
160+
task_queue.push_error(event)
161+
raise CompileError(event.data.error)
162+
case "sources":
163+
source: _TranspiledSource = event.data[0]
164+
# Write source files
165+
output.mkdir()
166+
_write_file(source.code, dir=output, muna=muna)
167+
_write_file(source.cmake, dir=output, muna=muna)
168+
_write_file(source.readme, dir=output, muna=muna)
169+
_write_file(source.example, dir=output, muna=muna)
170+
if source.resources:
171+
resource_path = output / "resources"
172+
resource_path.mkdir()
173+
for res in source.resources:
174+
_write_file(res.url, name=res.name, dir=resource_path, muna=muna)
119175

120176
def _load_predictor_func(path: str) -> Callable[...,object]:
121177
if "" not in sys.path:
@@ -142,6 +198,27 @@ def _compute_predictor_url(api_url: str, tag: str) -> str:
142198
predictor_url = urlunparse(parsed_url._replace(netloc=netloc, path=f"{tag}"))
143199
return predictor_url
144200

201+
def _get_current_target() -> str:
202+
match (platform.system().lower(), platform.machine().lower()):
203+
case ("darwin", "arm64"): return "arm64-apple-darwin"
204+
case ("linux", "aarch64"): return "aarch64-unknown-linux-gnu"
205+
case ("linux", "x86_64"): return "x86_64-unknown-linux-gnu"
206+
case ("windows", "arm64"): return "aarch64-pc-windows-msvc"
207+
case ("windows", "amd64"): return "x86_64-pc-windows-msvc"
208+
case (system, arch): raise ValueError(f"Cannot transpile because your system target is unsupported: {system} {arch}")
209+
210+
def _write_file(
211+
url: str,
212+
*,
213+
name: str=None,
214+
dir: Path,
215+
muna: Muna,
216+
) -> Path:
217+
name = name or Path(url).name
218+
path = dir / name
219+
muna.client.download(url, path, progress=True)
220+
return path
221+
145222
class _Predictor(BaseModel):
146223
tag: str
147224

@@ -162,13 +239,20 @@ class _ErrorEvent(BaseModel):
162239
event: Literal["error"]
163240
data: _ErrorData
164241

242+
class _TranspiledSource(BaseModel):
243+
code: str
244+
cmake: str
245+
readme: str
246+
example: str
247+
resources: list[PredictionResource]
248+
249+
class _SourceEvent(BaseModel):
250+
event: Literal["sources"]
251+
data: list[_TranspiledSource]
252+
165253
class CompileError(Exception):
166254
pass
167255

168-
class _TriagedCompileError(BaseModel):
169-
user: str
170-
internal: str
171-
172256
class ProgressLogQueue:
173257

174258
def __init__(self):

muna/cli/triage.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
#
2+
# Muna
3+
# Copyright © 2026 NatML Inc. All Rights Reserved.
4+
#
5+
6+
from pydantic import BaseModel
7+
from rich import print as print_rich
8+
from rich.panel import Panel
9+
from typer import Argument
10+
from typing import Annotated
11+
12+
from ..muna import Muna
13+
from .auth import get_access_key
14+
15+
def triage_predictor(
16+
reference_code: Annotated[
17+
str,
18+
Argument(help="Predictor compilation reference code.")
19+
]
20+
):
21+
muna = Muna(get_access_key())
22+
error = muna.client.request(
23+
method="GET",
24+
path=f"/predictors/triage?referenceCode={reference_code}",
25+
response_type=_TriagedCompileError
26+
)
27+
user_panel = Panel(
28+
error.user,
29+
title="User Error",
30+
title_align="left",
31+
highlight=True,
32+
border_style="bright_red"
33+
)
34+
internal_panel = Panel(
35+
error.internal,
36+
title="Internal Error",
37+
title_align="left",
38+
highlight=True,
39+
border_style="gold1"
40+
)
41+
print_rich(user_panel)
42+
print_rich(internal_panel)
43+
44+
class _TriagedCompileError(BaseModel):
45+
user: str
46+
internal: str

0 commit comments

Comments
 (0)