Skip to content

Commit 68332db

Browse files
committed
Add @prediction_endpoint decorator
1 parent 23c69d7 commit 68332db

File tree

4 files changed

+279
-1
lines changed

4 files changed

+279
-1
lines changed

muna/beta/remote/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,6 @@
33
# Copyright © 2026 NatML Inc. All Rights Reserved.
44
#
55

6+
from .endpoint import get_prediction_request, prediction_endpoint
67
from .prediction import PredictionService
78
from .schema import RemoteAcceleration, RemotePrediction, RemoteValue

muna/beta/remote/endpoint.py

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
#
2+
# Muna
3+
# Copyright © 2026 NatML Inc. All Rights Reserved.
4+
#
5+
6+
from __future__ import annotations
7+
from collections.abc import Callable, Iterator
8+
from contextlib import redirect_stdout, redirect_stderr
9+
from contextvars import ContextVar
10+
from datetime import datetime, timezone
11+
from functools import reduce, wraps
12+
from inspect import signature, Parameter
13+
from io import StringIO
14+
from pydantic import BaseModel
15+
from secrets import choice
16+
from time import perf_counter
17+
from traceback import format_exc
18+
from typing import Callable, ParamSpec, TypeVar
19+
20+
from .remote import _create_remote_value, _parse_remote_value
21+
from .schema import RemoteAcceleration, RemotePrediction, RemoteValue
22+
23+
# Parameter specification and return type of the function wrapped by `prediction_endpoint`.
# Both are only used in annotations.
P = ParamSpec("P")
R = TypeVar("R")

# Request currently being served by a `prediction_endpoint` wrapper.
# Defaults to `None`. The wrapper sets it before handling a request and resets it
# afterwards; `get_prediction_request` reads it.
_prediction_request: ContextVar[CreatePredictionInput | None] = ContextVar(
    "prediction_request",
    default=None
)
30+
31+
def prediction_endpoint(tag: str) -> Callable[
    [Callable[P, R]],
    Callable[[CreatePredictionInput], RemotePrediction | Iterator[RemotePrediction]]
]:
    """
    Wrap a function to handle serving remote prediction requests.

    The decorated function is replaced by a wrapper that accepts a single
    `CreatePredictionInput`. The wrapper checks that every required argument
    is present, deserializes the inputs, invokes the function with stdout and
    stderr captured into the prediction logs, and packages the outcome as a
    `RemotePrediction`. While the function (or the generator it returns) is
    running, the request is available through `get_prediction_request`.

    Any exception raised while handling the request is reported through the
    `error` field of the returned prediction instead of propagating. When
    streaming, a failure in the middle of the stream is yielded as a final
    error prediction and the stream then ends.

    Parameters:
        tag (str): Predictor tag.

    Returns:
        Callable: Decorator producing the request handler. The handler returns
        a single `RemotePrediction` when `stream` is false, or when the request
        fails before streaming starts. Otherwise it returns an iterator of
        `RemotePrediction`.
    """
    def decorator(func: Callable[P, R]) -> Callable[[CreatePredictionInput], RemotePrediction | Iterator[RemotePrediction]]:
        # Parameters without a default (ignoring `*args` / `**kwargs`) must be supplied
        sig = signature(func)
        required_params = {
            name for name, param in sig.parameters.items()
            if param.default is Parameter.empty
            and param.kind not in (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD)
        }
        # Define wrapper
        @wraps(func)
        def wrapper(input: CreatePredictionInput) -> RemotePrediction | Iterator[RemotePrediction]:
            prediction_id = _create_prediction_id()
            log_buffer = StringIO()
            start_time = perf_counter()

            def create_error_prediction() -> RemotePrediction:
                # Must be called from an `except` block so `format_exc` sees the active exception
                latency = (perf_counter() - start_time) * 1000 # millis
                return RemotePrediction(
                    id=prediction_id,
                    tag=tag,
                    latency=latency,
                    logs=log_buffer.getvalue(),
                    error=format_exc(),
                    created=datetime.now(timezone.utc).isoformat()
                )

            def stream_predictions(results: Iterator, created: str) -> Iterator[RemotePrediction]:
                # A generator function's body only runs while it is being iterated, which
                # happens after `wrapper` has returned. So the request context and the
                # output capture are re-established around every step. They are released
                # before yielding so that neither leaks into the consumer's code.
                exhausted = object()
                while True:
                    failed = False
                    step_token = _prediction_request.set(input)
                    try:
                        with redirect_stdout(log_buffer), redirect_stderr(log_buffer):
                            item = next(results, exhausted)
                        if item is exhausted:
                            return
                        prediction = _create_prediction(
                            id=prediction_id,
                            tag=tag,
                            results=item,
                            start_time=start_time,
                            logs=log_buffer,
                            created=created
                        )
                    except Exception:
                        prediction = create_error_prediction()
                        failed = True
                    finally:
                        _prediction_request.reset(step_token)
                    yield prediction
                    if failed:
                        return

            token = _prediction_request.set(input)
            try:
                # Check args
                missing_args = required_params - set(input.inputs.keys())
                if missing_args:
                    arg_name = next(iter(missing_args))
                    raise ValueError(
                        f"Failed to create prediction because required "
                        f"input argument `{arg_name}` was not provided."
                    )
                # Deserialize inputs
                kwargs = {
                    name: _parse_remote_value(value)
                    for name, value in input.inputs.items()
                }
                # Invoke function
                with redirect_stdout(log_buffer), redirect_stderr(log_buffer):
                    result = func(**kwargs)
                # Create prediction
                created = datetime.now(timezone.utc).isoformat()
                if input.stream:
                    return stream_predictions(
                        result if isinstance(result, Iterator) else iter([result]),
                        created
                    )
                if isinstance(result, Iterator):
                    # Drain the iterator and keep its last item. This is where a generator's
                    # body actually runs, so its output must be captured here as well.
                    with redirect_stdout(log_buffer), redirect_stderr(log_buffer):
                        result = reduce(lambda _, x: x, result)
                return _create_prediction(
                    id=prediction_id,
                    tag=tag,
                    results=result,
                    start_time=start_time,
                    logs=log_buffer,
                    created=created
                )
            except Exception:
                return create_error_prediction()
            finally:
                _prediction_request.reset(token)
        # Return
        return wrapper
    return decorator
116+
117+
def get_prediction_request() -> CreatePredictionInput | None:
    """
    Return the prediction request bound to the current context.

    Returns:
        CreatePredictionInput | None: The request being handled, or `None` when no request is in scope.
    """
    current_request = _prediction_request.get()
    return current_request
122+
123+
def _create_prediction(
    *,
    id: str,
    tag: str,
    results: object,
    start_time: float,
    logs: StringIO,
    created: str
) -> RemotePrediction:
    """
    Build a `RemotePrediction` from a function's return value.

    A tuple return value is treated as several results; any other value is
    treated as a single result. Latency is the time elapsed since `start_time`
    (a `perf_counter` reading), in milliseconds.
    """
    elapsed_ms = (perf_counter() - start_time) * 1000 # millis
    if isinstance(results, tuple):
        outputs = list(results)
    else:
        outputs = [results]
    serialized_outputs = list(map(_create_remote_value, outputs))
    captured_logs = logs.getvalue()
    return RemotePrediction(
        id=id,
        tag=tag,
        results=serialized_outputs,
        latency=elapsed_ms,
        logs=captured_logs,
        created=created
    )
143+
144+
def _create_prediction_id() -> str:
    """
    Generate a random prediction identifier.

    Returns:
        str: `pred_` followed by 21 characters drawn from digits and ASCII letters.
    """
    charset = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
    picked = [choice(charset) for _ in range(21)]
    suffix = "".join(picked)
    return f"pred_{suffix}"
148+
149+
class CreatePredictionInput(BaseModel):
    """
    Request payload accepted by functions wrapped with `prediction_endpoint`.
    """
    # NOTE(review): `api_url`, `access_key`, `tag` and `acceleration` are not read by the
    # endpoint wrapper in the code visible here; their consumers live elsewhere. Confirm
    # their semantics before relying on them.
    api_url: str
    access_key: str
    tag: str
    # Serialized input values keyed by name. The endpoint wrapper parses each value
    # and passes it to the wrapped function as a keyword argument of the same name.
    inputs: dict[str, RemoteValue]
    acceleration: RemoteAcceleration | str
    # When true, the endpoint wrapper returns an iterator of predictions rather than a single one.
    stream: bool = False

muna/beta/remote/schema.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,4 +28,4 @@ class RemotePrediction(Prediction):
2828
"""
2929
Remote prediction.
3030
"""
31-
results: list[RemoteValue] | None = Field(description="Prediction results.")
31+
results: list[RemoteValue] | None = Field(default=None, description="Prediction results.")

test/remote_test.py

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
#
2+
# Muna
3+
# Copyright © 2026 NatML Inc. All Rights Reserved.
4+
#
5+
6+
from muna.beta.remote import get_prediction_request, prediction_endpoint, RemotePrediction
7+
from muna.beta.remote.remote import _create_remote_value
8+
from muna.beta.remote.endpoint import CreatePredictionInput
9+
from muna.types import Dtype
10+
from typing import Iterator
11+
12+
def test_prediction_endpoint_create_eager():
    """A non-streaming request to the eager endpoint returns one prediction with a string result."""
    serialized_inputs = {
        "name": _create_remote_value("Yusuf"),
        "age": _create_remote_value(67)
    }
    request = CreatePredictionInput(
        api_url="https://api.muna.ai/v1",
        access_key="",
        tag="@muna/greeting",
        inputs=serialized_inputs,
        acceleration="remote_cpu",
        stream=False
    )
    response = _prediction_endpoint_eager(request)
    assert isinstance(response, RemotePrediction)
    outputs = response.results
    assert isinstance(outputs, list)
    assert outputs[0].type == Dtype.string
28+
29+
def test_prediction_endpoint_stream_eager():
    """Streaming from an endpoint whose function returns a plain value yields exactly one prediction."""
    # The endpoint under test is the module-level `_prediction_endpoint_eager`.
    payload = CreatePredictionInput(
        api_url="https://api.muna.ai/v1",
        access_key="",
        tag="@muna/greeting",
        inputs={
            "name": _create_remote_value("Yusuf"),
            "age": _create_remote_value(67)
        },
        acceleration="remote_cpu",
        stream=True
    )
    stream = _prediction_endpoint_eager(payload)
    assert isinstance(stream, Iterator)
    predictions = list(stream)
    assert len(predictions) == 1
    assert isinstance(predictions[0], RemotePrediction)
47+
48+
def test_prediction_endpoint_create_generator():
    """A non-streaming request to the generator endpoint returns a single prediction with a string result."""
    serialized_inputs = {
        "name": _create_remote_value("Yusuf"),
        "age": _create_remote_value(67)
    }
    request = CreatePredictionInput(
        api_url="https://api.muna.ai/v1",
        access_key="",
        tag="@muna/greeting",
        inputs=serialized_inputs,
        acceleration="remote_cpu",
        stream=False
    )
    response = _prediction_endpoint_generator(request)
    assert isinstance(response, RemotePrediction)
    outputs = response.results
    assert isinstance(outputs, list)
    assert outputs[0].type == Dtype.string
64+
65+
def test_prediction_endpoint_stream_generator():
    """A streaming request to the generator endpoint yields one prediction per yielded value."""
    serialized_inputs = {
        "name": _create_remote_value("Yusuf"),
        "age": _create_remote_value(67)
    }
    request = CreatePredictionInput(
        api_url="https://api.muna.ai/v1",
        access_key="",
        tag="@muna/greeting",
        inputs=serialized_inputs,
        acceleration="remote_cpu",
        stream=True
    )
    responses = _prediction_endpoint_generator(request)
    assert isinstance(responses, Iterator)
    collected = list(responses)
    assert len(collected) == 2
80+
81+
def test_prediction_endpoint_create_missing_input():
    """Omitting a required argument produces a prediction carrying an error instead of raising."""
    incomplete_inputs = { "name": _create_remote_value("Yusuf") }
    request = CreatePredictionInput(
        api_url="https://api.muna.ai/v1",
        access_key="",
        tag="@muna/greeting",
        inputs=incomplete_inputs,
        acceleration="remote_cpu",
        stream=False
    )
    response = _prediction_endpoint_eager(request)
    assert isinstance(response, RemotePrediction)
    assert response.error is not None
93+
94+
def test_prediction_request_input():
    """The current request is visible inside the endpoint function and cleared once it returns."""
    observed = []

    @prediction_endpoint(tag="")
    def predict(name: str, age: int):
        # Record what the function sees while the request is being served
        observed.append(get_prediction_request())

    serialized_inputs = {
        "name": _create_remote_value("Yusuf"),
        "age": _create_remote_value(67)
    }
    request = CreatePredictionInput(
        api_url="https://api.muna.ai/v1",
        access_key="",
        tag="",
        inputs=serialized_inputs,
        acceleration="remote_cpu",
        stream=False
    )
    predict(request)
    assert observed and observed[0] is not None
    assert get_prediction_request() is None
114+
115+
@prediction_endpoint(tag="@muna/greeting")
def _prediction_endpoint_eager(name: str, age: int) -> str:
    """Test endpoint that returns a single greeting string."""
    greeting = f"Hello {name}! You are {age} years old."
    return greeting
118+
119+
@prediction_endpoint(tag="@muna/greeting")
def _prediction_endpoint_generator(name: str, age: int) -> Iterator[str]:
    """Test endpoint that yields a partial greeting followed by the full greeting."""
    partial_greeting = f"Hello {name}!"
    yield partial_greeting
    full_greeting = f"Hello {name}! You are {age} years old."
    yield full_greeting

0 commit comments

Comments
 (0)