Commit 2044140

Iterate over full prediction fragments (#5)
Emit full prediction fragments from the streaming iteration API. Also avoid emitting the async API stability warning from the test suite's model loading and unloading helper commands.
1 parent d7fbb7c commit 2044140
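
From the caller's side, iterating a prediction stream now yields LlmPredictionFragment objects rather than bare strings, so the streamed text is read from fragment.content. A minimal sketch of the sync usage, mirroring the test suite's use of the private _complete_stream helper (the top-level Client import path and the model identifier are assumptions for illustration, not taken from this diff):

    # Sketch only: stream a completion and print the text as it arrives.
    # After this commit each iterated item is an LlmPredictionFragment.
    from lmstudio import Client  # import path assumed

    with Client() as client:
        session = client.llm
        prediction_stream = session._complete_stream(
            "some-model-id",  # placeholder model identifier
            "The capital of France is",
        )
        with prediction_stream:
            for fragment in prediction_stream:
                print(fragment.content, end="", flush=True)
        result = prediction_stream.result()  # full PredictionResult, as before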

5 files changed: +67 / -63 lines

src/lmstudio/async_api.py

Lines changed: 13 additions & 13 deletions
@@ -909,15 +909,15 @@ async def __aexit__(
             self._set_error(exc_val)
         await self.aclose()
 
-    async def __aiter__(self) -> AsyncIterator[str]:
+    async def __aiter__(self) -> AsyncIterator[LlmPredictionFragment]:
         endpoint = self._endpoint
         async with self:
             assert self._channel is not None
             async for contents in self._channel.rx_stream():
                 for event in endpoint.iter_message_events(contents):
                     endpoint.handle_rx_event(event)
                     if isinstance(event, PredictionFragmentEvent):
-                        yield event.arg.content
+                        yield event.arg
                 if endpoint.is_finished:
                     break
         self._mark_finished()
@@ -1008,8 +1008,8 @@ async def _complete_stream(
             on_prompt_processing_progress,
         )
         channel_cm = self._create_channel(endpoint)
-        prediction = AsyncPredictionStream(channel_cm, endpoint)
-        return prediction
+        prediction_stream = AsyncPredictionStream(channel_cm, endpoint)
+        return prediction_stream
 
     @overload
     async def _respond_stream(
@@ -1064,8 +1064,8 @@ async def _respond_stream(
             on_prompt_processing_progress,
         )
         channel_cm = self._create_channel(endpoint)
-        prediction = AsyncPredictionStream(channel_cm, endpoint)
-        return prediction
+        prediction_stream = AsyncPredictionStream(channel_cm, endpoint)
+        return prediction_stream
 
     async def _apply_prompt_template(
         self,
@@ -1264,7 +1264,7 @@ async def complete(
         on_prompt_processing_progress: Callable[[float], None] | None = None,
     ) -> PredictionResult[str] | PredictionResult[DictObject]:
         """Request a one-off prediction without any context."""
-        prediction = await self._session._complete_stream(
+        prediction_stream = await self._session._complete_stream(
             self.identifier,
             prompt,
             response_format=response_format,
@@ -1274,11 +1274,11 @@ async def complete(
             on_prediction_fragment=on_prediction_fragment,
             on_prompt_processing_progress=on_prompt_processing_progress,
         )
-        async for _ in prediction:
+        async for _ in prediction_stream:
             # No yield in body means iterator reliably provides
             # prompt resource cleanup on coroutine cancellation
             pass
-        return prediction.result()
+        return prediction_stream.result()
 
     @overload
     async def respond_stream(
@@ -1365,7 +1365,7 @@ async def respond(
         on_prompt_processing_progress: Callable[[float], None] | None = None,
     ) -> PredictionResult[str] | PredictionResult[DictObject]:
         """Request a response in an ongoing assistant chat session."""
-        prediction = await self._session._respond_stream(
+        prediction_stream = await self._session._respond_stream(
             self.identifier,
             history,
             response_format=response_format,
@@ -1375,11 +1375,11 @@ async def respond(
             on_prediction_fragment=on_prediction_fragment,
             on_prompt_processing_progress=on_prompt_processing_progress,
         )
-        async for _ in prediction:
+        async for _ in prediction_stream:
             # No yield in body means iterator reliably provides
             # prompt resource cleanup on coroutine cancellation
             pass
-        return prediction.result()
+        return prediction_stream.result()
 
     @sdk_public_api_async()
     async def apply_prompt_template(
@@ -1411,7 +1411,7 @@ async def embed(
 TAsyncSession = TypeVar("TAsyncSession", bound=AsyncSession)
 
 _ASYNC_API_STABILITY_WARNING = """\
-Note: the async API is not yet stable and is expected to change in future releases
+Note the async API is not yet stable and is expected to change in future releases
 """

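The async stream mirrors this: __aiter__ now yields LlmPredictionFragment objects, so the response text is reassembled from fragment.content, as the updated async tests do with an async comprehension. A hedged sketch under the same assumptions as above (Client import path and model identifier are placeholders):

    # Sketch only: async streaming after this change.
    import asyncio

    from lmstudio import AsyncClient  # import path assumed


    async def main() -> None:
        async with AsyncClient() as client:
            session = client.llm
            prediction_stream = await session._complete_stream(
                "some-model-id", "The capital of France is"
            )
            async with prediction_stream:
                # Each iterated item is an LlmPredictionFragment, not a str
                text = "".join(
                    [fragment.content async for fragment in prediction_stream]
                )
            print(text)
            print(prediction_stream.result())  # PredictionResult still available


    asyncio.run(main())
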
src/lmstudio/sync_api.py

Lines changed: 12 additions & 12 deletions
@@ -1062,10 +1062,10 @@ def __exit__(
             self._set_error(exc_val)
         self.close()
 
-    def __iter__(self) -> Iterator[str]:
+    def __iter__(self) -> Iterator[LlmPredictionFragment]:
         for event in self._iter_events():
             if isinstance(event, PredictionFragmentEvent):
-                yield event.arg.content
+                yield event.arg
 
     def _iter_events(self) -> Iterator[PredictionRxEvent]:
         endpoint = self._endpoint
@@ -1165,8 +1165,8 @@ def _complete_stream(
             on_prompt_processing_progress,
         )
         channel_cm = self._create_channel(endpoint)
-        prediction = PredictionStream(channel_cm, endpoint)
-        return prediction
+        prediction_stream = PredictionStream(channel_cm, endpoint)
+        return prediction_stream
 
     @overload
     def _respond_stream(
@@ -1221,8 +1221,8 @@ def _respond_stream(
             on_prompt_processing_progress,
         )
         channel_cm = self._create_channel(endpoint)
-        prediction = PredictionStream(channel_cm, endpoint)
-        return prediction
+        prediction_stream = PredictionStream(channel_cm, endpoint)
+        return prediction_stream
 
     def _apply_prompt_template(
         self,
@@ -1419,7 +1419,7 @@ def complete(
         on_prompt_processing_progress: Callable[[float], None] | None = None,
     ) -> PredictionResult[str] | PredictionResult[DictObject]:
         """Request a one-off prediction without any context."""
-        prediction = self._session._complete_stream(
+        prediction_stream = self._session._complete_stream(
             self.identifier,
             prompt,
             response_format=response_format,
@@ -1429,11 +1429,11 @@ def complete(
             on_prediction_fragment=on_prediction_fragment,
             on_prompt_processing_progress=on_prompt_processing_progress,
         )
-        for _ in prediction:
+        for _ in prediction_stream:
             # No yield in body means iterator reliably provides
             # prompt resource cleanup on coroutine cancellation
             pass
-        return prediction.result()
+        return prediction_stream.result()
 
     @overload
     def respond_stream(
@@ -1520,7 +1520,7 @@ def respond(
         on_prompt_processing_progress: Callable[[float], None] | None = None,
    ) -> PredictionResult[str] | PredictionResult[DictObject]:
         """Request a response in an ongoing assistant chat session."""
-        prediction = self._session._respond_stream(
+        prediction_stream = self._session._respond_stream(
             self.identifier,
             history,
             response_format=response_format,
@@ -1530,11 +1530,11 @@ def respond(
             on_prediction_fragment=on_prediction_fragment,
             on_prompt_processing_progress=on_prompt_processing_progress,
         )
-        for _ in prediction:
+        for _ in prediction_stream:
             # No yield in body means iterator reliably provides
             # prompt resource cleanup on coroutine cancellation
             pass
-        return prediction.result()
+        return prediction_stream.result()
 
     # Multi-round predictions are currently a sync-only handle-only feature
     # TODO: Refactor to allow for more code sharing with the async API

tests/async/test_inference_async.py

Lines changed: 22 additions & 18 deletions
@@ -72,17 +72,17 @@ async def test_complete_stream_async(caplog: LogCap) -> None:
     model_id = EXPECTED_LLM_ID
     async with AsyncClient() as client:
         session = client.llm
-        prediction = await session._complete_stream(
+        prediction_stream = await session._complete_stream(
             model_id, prompt, config=SHORT_PREDICTION_CONFIG
         )
-        assert isinstance(prediction, AsyncPredictionStream)
+        assert isinstance(prediction_stream, AsyncPredictionStream)
         # Also exercise the explicit context management interface
-        async with prediction:
-            async for token in prediction:
-                logging.info(f"Token: {token}")
-                assert token
-                assert isinstance(token, str)
-        response = prediction.result()
+        async with prediction_stream:
+            async for fragment in prediction_stream:
+                logging.info(f"Fragment: {fragment}")
+                assert fragment.content
+                assert isinstance(fragment.content, str)
+        response = prediction_stream.result()
         # The continuation from the LLM will change, but it won't be an empty string
         logging.info(f"LLM response: {response!r}")
         assert isinstance(response, PredictionResult)
@@ -151,7 +151,9 @@ def record_fragment(fragment: LlmPredictionFragment) -> None:
         # This test case also covers the explicit context management interface
         iteration_content: list[str] = []
         async with prediction_stream:
-            iteration_content = [text async for text in prediction_stream]
+            iteration_content = [
+                fragment.content async for fragment in prediction_stream
+            ]
         assert len(messages) == 1
         message = messages[0]
         assert message.role == "assistant"
@@ -206,7 +208,9 @@ def record_fragment(fragment: LlmPredictionFragment) -> None:
         # This test case also covers the explicit context management interface
         iteration_content: list[str] = []
         async with prediction_stream:
-            iteration_content = [text async for text in prediction_stream]
+            iteration_content = [
+                fragment.content async for fragment in prediction_stream
+            ]
         assert len(messages) == 1
         message = messages[0]
         assert message.role == "assistant"
@@ -267,10 +271,10 @@ async def test_invalid_model_request_stream_async(caplog: LogCap) -> None:
         # This should error rather than timing out,
         # but avoid any risk of the client hanging...
         async with asyncio.timeout(30):
-            prediction = await model.complete_stream("Some text")
-            async with prediction:
+            prediction_stream = await model.complete_stream("Some text")
+            async with prediction_stream:
                 with pytest.raises(LMStudioModelNotFoundError) as exc_info:
-                    await prediction.wait_for_result()
+                    await prediction_stream.wait_for_result()
                 check_sdk_error(exc_info, __file__)
@@ -283,11 +287,11 @@ async def test_cancel_prediction_async(caplog: LogCap) -> None:
     caplog.set_level(logging.DEBUG)
     async with AsyncClient() as client:
         session = client.llm
-        response = await session._complete_stream(model_id, prompt=prompt)
-        async for _ in response:
-            await response.cancel()
+        stream = await session._complete_stream(model_id, prompt=prompt)
+        async for _ in stream:
+            await stream.cancel()
             num_times += 1
-        assert response.stats
-        assert response.stats.stop_reason == "userStopped"
+        assert stream.stats
+        assert stream.stats.stop_reason == "userStopped"
     # ensure __aiter__ closes correctly
     assert num_times == 1

tests/sync/test_inference_sync.py

Lines changed: 18 additions & 18 deletions
@@ -76,17 +76,17 @@ def test_complete_stream_sync(caplog: LogCap) -> None:
     model_id = EXPECTED_LLM_ID
     with Client() as client:
         session = client.llm
-        prediction = session._complete_stream(
+        prediction_stream = session._complete_stream(
             model_id, prompt, config=SHORT_PREDICTION_CONFIG
         )
-        assert isinstance(prediction, PredictionStream)
+        assert isinstance(prediction_stream, PredictionStream)
         # Also exercise the explicit context management interface
-        with prediction:
-            for token in prediction:
-                logging.info(f"Token: {token}")
-                assert token
-                assert isinstance(token, str)
-        response = prediction.result()
+        with prediction_stream:
+            for fragment in prediction_stream:
+                logging.info(f"Fragment: {fragment}")
+                assert fragment.content
+                assert isinstance(fragment.content, str)
+        response = prediction_stream.result()
         # The continuation from the LLM will change, but it won't be an empty string
         logging.info(f"LLM response: {response!r}")
         assert isinstance(response, PredictionResult)
@@ -153,7 +153,7 @@ def record_fragment(fragment: LlmPredictionFragment) -> None:
         # This test case also covers the explicit context management interface
         iteration_content: list[str] = []
         with prediction_stream:
-            iteration_content = [text for text in prediction_stream]
+            iteration_content = [fragment.content for fragment in prediction_stream]
         assert len(messages) == 1
         message = messages[0]
         assert message.role == "assistant"
@@ -207,7 +207,7 @@ def record_fragment(fragment: LlmPredictionFragment) -> None:
         # This test case also covers the explicit context management interface
         iteration_content: list[str] = []
         with prediction_stream:
-            iteration_content = [text for text in prediction_stream]
+            iteration_content = [fragment.content for fragment in prediction_stream]
         assert len(messages) == 1
         message = messages[0]
         assert message.role == "assistant"
@@ -265,10 +265,10 @@ def test_invalid_model_request_stream_sync(caplog: LogCap) -> None:
         # This should error rather than timing out,
         # but avoid any risk of the client hanging...
         with nullcontext():
-            prediction = model.complete_stream("Some text")
-            with prediction:
+            prediction_stream = model.complete_stream("Some text")
+            with prediction_stream:
                 with pytest.raises(LMStudioModelNotFoundError) as exc_info:
-                    prediction.wait_for_result()
+                    prediction_stream.wait_for_result()
                 check_sdk_error(exc_info, __file__)
@@ -280,11 +280,11 @@ def test_cancel_prediction_sync(caplog: LogCap) -> None:
     caplog.set_level(logging.DEBUG)
     with Client() as client:
         session = client.llm
-        response = session._complete_stream(model_id, prompt=prompt)
-        for _ in response:
-            response.cancel()
+        stream = session._complete_stream(model_id, prompt=prompt)
+        for _ in stream:
+            stream.cancel()
             num_times += 1
-        assert response.stats
-        assert response.stats.stop_reason == "userStopped"
+        assert stream.stats
+        assert stream.stats.stop_reason == "userStopped"
     # ensure __aiter__ closes correctly
     assert num_times == 1

tox.ini

Lines changed: 2 additions & 2 deletions
@@ -24,11 +24,11 @@ commands =
 
 [testenv:load-test-models]
 commands =
-    python -m tests.load_models
+    python -W "ignore:Note the async API is not yet stable:FutureWarning" -m tests.load_models
 
 [testenv:unload-test-models]
 commands =
-    python -m tests.unload_models
+    python -W "ignore:Note the async API is not yet stable:FutureWarning" -m tests.unload_models
 
 [testenv:coverage]
 # Subprocess coverage based on https://hynek.me/articles/turbo-charge-tox/

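The tox change relies on Python's standard -W filter syntax, whose colon-separated fields are action:message:category and whose message field is matched against the start of the warning text. That is presumably also why the colon after "Note" was dropped from the warning message itself: a literal colon cannot appear inside the message field of a -W filter. A sketch of the equivalent in-process filter:

    # Sketch: suppress the SDK's async API stability warning programmatically.
    # Matches warnings whose message starts with the given text.
    import warnings

    warnings.filterwarnings(
        "ignore",
        message="Note the async API is not yet stable",
        category=FutureWarning,
    )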