66from importlib .util import module_from_spec , spec_from_file_location
77from inspect import getmembers , getmodulename , isfunction
88from pathlib import Path
9+ import platform
910from pydantic import BaseModel
1011from rich import print as print_rich
11- from rich .panel import Panel
1212import sys
1313from typer import Argument , Option
14- from typing import Callable , Literal
15- from typing_extensions import Annotated
14+ from typing import Annotated , Callable , Literal
1615from urllib .parse import urlparse , urlunparse
1716
1817from ..client import MunaAPIError
1918from ..compile import PredictorSpec
19+ from ..logging import CustomProgress , CustomProgressTask
2020from ..muna import Muna
2121from ..sandbox import EntrypointCommand
22- from ..logging import CustomProgress , CustomProgressTask
22+ from ..types import PredictionResource
2323from .auth import get_access_key
2424
25- def compile_predictor (
26- path : str = Argument (..., help = "Predictor path." ),
27- overwrite : bool = Option (False , "--overwrite" , help = "Whether to delete any existing predictor with the same tag before compiling." ),
25+ def compile_function (
26+ path : Annotated [str , Argument (
27+ resolve_path = True ,
28+ exists = True ,
29+ readable = True ,
30+ file_okay = True ,
31+ dir_okay = False ,
32+ help = "Python source path."
33+ )],
34+ overwrite : Annotated [bool , Option (
35+ "--overwrite" ,
36+ help = "Whether to delete any existing predictor with the same tag before compiling." )
37+ ]= False ,
2838):
2939 muna = Muna (get_access_key ())
3040 path : Path = Path (path ).resolve ()
@@ -49,7 +59,7 @@ def compile_predictor(
4959 f"have a docstring."
5060 )
5161 spec .description = func .__doc__ .strip ()
52- task .finish (f"Loaded prediction function: [bold cyan]{ spec .tag } [/bold cyan]" )
62+ task .finish (f"Loaded Python function: [bold cyan]{ spec .tag } [/bold cyan]" )
5363 # Populate
5464 sandbox = spec .sandbox
5565 sandbox .commands .append (entrypoint )
@@ -83,39 +93,85 @@ def compile_predictor(
8393 ),
8494 response_type = _LogEvent | _ErrorEvent
8595 ):
86- if isinstance (event , _LogEvent ):
87- task_queue .push_log (event )
88- elif isinstance (event , _ErrorEvent ):
89- task_queue .push_error (event )
90- raise CompileError (event .data .error )
96+ match event .event :
97+ case "log" :
98+ task_queue .push_log (event )
99+ case "error" :
100+ task_queue .push_error (event )
101+ raise CompileError (event .data .error )
91102 predictor_url = _compute_predictor_url (muna .client .api_url , spec .tag )
92103 print_rich (f"\n [bold spring_green3]🎉 Predictor is now being compiled.[/bold spring_green3] Check it out at [link={ predictor_url } ]{ predictor_url } [/link]" )
93104
94- def triage_predictor (
95- reference_code : Annotated [str , Argument (help = "Predictor compilation reference code." )]
105+ def transpile_function (
106+ path : Annotated [Path , Argument (
107+ resolve_path = True ,
108+ exists = True ,
109+ readable = True ,
110+ file_okay = True ,
111+ dir_okay = False ,
112+ help = "Python source path."
113+ )],
114+ output : Annotated [Path , Argument (
115+ resolve_path = True ,
116+ exists = False ,
117+ writable = True ,
118+ help = "Output path for generated C++ sources."
119+ )]= Path ("cpp" )
96120):
97121 muna = Muna (get_access_key ())
98- error = muna .client .request (
99- method = "GET" ,
100- path = f"/predictors/triage?referenceCode={ reference_code } " ,
101- response_type = _TriagedCompileError
102- )
103- user_panel = Panel (
104- error .user ,
105- title = "User Error" ,
106- title_align = "left" ,
107- highlight = True ,
108- border_style = "bright_red"
109- )
110- internal_panel = Panel (
111- error .internal ,
112- title = "Internal Error" ,
113- title_align = "left" ,
114- highlight = True ,
115- border_style = "gold1"
116- )
117- print_rich (user_panel )
118- print_rich (internal_panel )
122+ with CustomProgress ():
123+ # Load
124+ with CustomProgressTask (loading_text = "Loading predictor..." ) as task :
125+ func = _load_predictor_func (path )
126+ entrypoint = EntrypointCommand (
127+ from_path = str (path ),
128+ to_path = f"./{ path .name } " ,
129+ name = func .__name__
130+ )
131+ spec : PredictorSpec = func .__predictor_spec
132+ spec .targets = (
133+ spec .targets
134+ if spec .targets is not None
135+ else [_get_current_target ()]
136+ )
137+ task .finish (f"Loaded Python function: [bold cyan]{ func .__module__ } .{ func .__name__ } [/bold cyan]" )
138+ # Populate
139+ sandbox = spec .sandbox
140+ sandbox .commands .append (entrypoint )
141+ with CustomProgressTask (loading_text = "Uploading sandbox..." , done_text = "Uploaded sandbox" ):
142+ sandbox .populate (muna = muna )
143+ # Compile
144+ with CustomProgressTask (loading_text = "Running codegen..." , done_text = "Completed codegen" ):
145+ with ProgressLogQueue () as task_queue :
146+ for event in muna .client .stream (
147+ method = "POST" ,
148+ path = f"/transpile" ,
149+ body = spec .model_dump (
150+ mode = "json" ,
151+ exclude = spec .model_extra .keys (),
152+ by_alias = True
153+ ),
154+ response_type = _LogEvent | _ErrorEvent | _SourceEvent
155+ ):
156+ match event .event :
157+ case "log" :
158+ task_queue .push_log (event )
159+ case "error" :
160+ task_queue .push_error (event )
161+ raise CompileError (event .data .error )
162+ case "sources" :
163+ source : _TranspiledSource = event .data [0 ]
164+ # Write source files
165+ output .mkdir ()
166+ _write_file (source .code , dir = output , muna = muna )
167+ _write_file (source .cmake , dir = output , muna = muna )
168+ _write_file (source .readme , dir = output , muna = muna )
169+ _write_file (source .example , dir = output , muna = muna )
170+ if source .resources :
171+ resource_path = output / "resources"
172+ resource_path .mkdir ()
173+ for res in source .resources :
174+ _write_file (res .url , name = res .name , dir = resource_path , muna = muna )
119175
120176def _load_predictor_func (path : str ) -> Callable [...,object ]:
121177 if "" not in sys .path :
@@ -142,6 +198,27 @@ def _compute_predictor_url(api_url: str, tag: str) -> str:
142198 predictor_url = urlunparse (parsed_url ._replace (netloc = netloc , path = f"{ tag } " ))
143199 return predictor_url
144200
201+ def _get_current_target () -> str :
202+ match (platform .system ().lower (), platform .machine ().lower ()):
203+ case ("darwin" , "arm64" ): return "arm64-apple-darwin"
204+ case ("linux" , "aarch64" ): return "aarch64-unknown-linux-gnu"
205+ case ("linux" , "x86_64" ): return "x86_64-unknown-linux-gnu"
206+ case ("windows" , "arm64" ): return "aarch64-pc-windows-msvc"
207+ case ("windows" , "amd64" ): return "x86_64-pc-windows-msvc"
208+ case (system , arch ): raise ValueError (f"Cannot transpile because your system target is unsupported: { system } { arch } " )
209+
210+ def _write_file (
211+ url : str ,
212+ * ,
213+ name : str = None ,
214+ dir : Path ,
215+ muna : Muna ,
216+ ) -> Path :
217+ name = name or Path (url ).name
218+ path = dir / name
219+ muna .client .download (url , path , progress = True )
220+ return path
221+
145222class _Predictor (BaseModel ):
146223 tag : str
147224
@@ -162,13 +239,20 @@ class _ErrorEvent(BaseModel):
162239 event : Literal ["error" ]
163240 data : _ErrorData
164241
242+ class _TranspiledSource (BaseModel ):
243+ code : str
244+ cmake : str
245+ readme : str
246+ example : str
247+ resources : list [PredictionResource ]
248+
249+ class _SourceEvent (BaseModel ):
250+ event : Literal ["sources" ]
251+ data : list [_TranspiledSource ]
252+
165253class CompileError (Exception ):
166254 pass
167255
168- class _TriagedCompileError (BaseModel ):
169- user : str
170- internal : str
171-
172256class ProgressLogQueue :
173257
174258 def __init__ (self ):
0 commit comments