|
3 | 3 | # Copyright © 2026 NatML Inc. All Rights Reserved. |
4 | 4 | # |
5 | 5 |
|
| 6 | +from base64 import b64decode |
6 | 7 | from importlib.util import module_from_spec, spec_from_file_location |
7 | 8 | from inspect import getmembers, getmodulename, isfunction |
| 9 | +from os import environ |
8 | 10 | from pathlib import Path |
9 | 11 | import platform |
10 | 12 | from pydantic import BaseModel |
| 13 | +import re |
| 14 | +from requests import get as http_get |
11 | 15 | from rich import print as print_rich |
12 | 16 | import sys |
| 17 | +from tempfile import TemporaryDirectory |
13 | 18 | from typer import Argument, Option |
14 | 19 | from typing import Annotated, Callable, Literal |
15 | 20 | from urllib.parse import urlparse, urlunparse |
@@ -103,23 +108,25 @@ def compile_function( |
103 | 108 | print_rich(f"\n[bold spring_green3]🎉 Predictor is now being compiled.[/bold spring_green3] Check it out at [link={predictor_url}]{predictor_url}[/link]") |
104 | 109 |
|
105 | 110 | def transpile_function( |
106 | | - path: Annotated[Path, Argument( |
107 | | - resolve_path=True, |
108 | | - exists=True, |
109 | | - readable=True, |
110 | | - file_okay=True, |
111 | | - dir_okay=False, |
112 | | - help="Python source path." |
113 | | - )], |
| 111 | + path: Annotated[ |
| 112 | + str, |
| 113 | + Argument(help="Python source path or GitHub URL.") |
| 114 | + ], |
114 | 115 | output: Annotated[Path, Option( |
115 | 116 | resolve_path=True, |
116 | 117 | exists=False, |
117 | 118 | writable=True, |
118 | 119 | help="Output path for generated C++ sources." |
119 | | - )]=Path("cpp") |
| 120 | + )]=Path("cpp"), |
| 121 | + trust_remote_code: Annotated[bool, Option( |
| 122 | + "--trust-remote-code", |
| 123 | + help="Trust and execute code from remote URLs. Required when using GitHub URLs." |
| 124 | + )]=False, |
120 | 125 | ): |
121 | 126 | muna = Muna(get_access_key()) |
122 | | - # Check path |
| 127 | + # Resolve path |
| 128 | + path: Path = _resolve_source_path(path, trust_remote_code=trust_remote_code) |
| 129 | + # Check output |
123 | 130 | if output.exists(): |
124 | 131 | raise ValueError(f"Cannot transpile because output directory already exists: {output}") |
125 | 132 | with CustomProgress(): |
@@ -222,6 +229,62 @@ def _write_file( |
222 | 229 | muna.client.download(url, path, progress=True) |
223 | 230 | return path |
224 | 231 |
|
| 232 | +def _resolve_source_path( |
| 233 | + path: str, |
| 234 | + *, |
| 235 | + trust_remote_code: bool |
| 236 | +) -> Path: |
| 237 | + # GitHub URL |
| 238 | + github_match = re.match(r"https://github\.com/([^/]+)/([^/]+)/blob/([^/]+)/(.+)", path) |
| 239 | + if github_match: |
| 240 | + if not trust_remote_code: |
| 241 | + raise ValueError( |
| 242 | + "Cannot transpile remote code without explicit trust. " |
| 243 | + "Pass --trust-remote-code to confirm you trust this code." |
| 244 | + ) |
| 245 | + owner, repo, ref, file_path = github_match.groups() |
| 246 | + return _download_github_file(owner, repo, ref, file_path) |
| 247 | + # Local path |
| 248 | + local_path = Path(path).resolve() |
| 249 | + if not local_path.exists(): |
| 250 | + raise ValueError(f"Cannot transpile because no file exists at path: {local_path}") |
| 251 | + if not local_path.is_file(): |
| 252 | + raise ValueError(f"Cannot transpile because path is not a file: {local_path}") |
| 253 | + return local_path |
| 254 | + |
| 255 | +def _download_github_file( |
| 256 | + owner: str, |
| 257 | + repo: str, |
| 258 | + ref: str, |
| 259 | + file_path: str |
| 260 | +) -> Path: |
| 261 | + api_url = f"https://api.github.com/repos/{owner}/{repo}/contents/{file_path}" |
| 262 | + headers = { "Accept": "application/vnd.github.v3+json" } |
| 263 | + github_token = environ.get("GITHUB_TOKEN") |
| 264 | + if github_token: |
| 265 | + headers["Authorization"] = f"Bearer {github_token}" |
| 266 | + response = http_get(api_url, headers=headers, params={ "ref": ref }) |
| 267 | + if response.status_code == 404: |
| 268 | + raise ValueError( |
| 269 | + f"Cannot transpile because file not found: {file_path} at ref '{ref}' " |
| 270 | + f"in repository {owner}/{repo}. Note: branch names containing '/' are not supported." |
| 271 | + ) |
| 272 | + response.raise_for_status() |
| 273 | + data = response.json() |
| 274 | + # Decode content |
| 275 | + if data.get("encoding") == "base64": |
| 276 | + content = b64decode(data["content"]) |
| 277 | + else: |
| 278 | + download_response = http_get(data["download_url"], headers=headers) |
| 279 | + download_response.raise_for_status() |
| 280 | + content = download_response.content |
| 281 | + # Write to temp file |
| 282 | + with TemporaryDirectory(delete=False) as tmp_dir: |
| 283 | + name = Path(file_path).name |
| 284 | + path = Path(tmp_dir) / name |
| 285 | + path.write_bytes(content) |
| 286 | + return path |
| 287 | + |
225 | 288 | class _Predictor(BaseModel): |
226 | 289 | tag: str |
227 | 290 |
|
|
0 commit comments