
Commit ac01ff2

Updated LocalLab v0.2.8
1 parent 40c1409 commit ac01ff2

File tree

5 files changed: +90 / -20 lines changed

CHANGELOG.md

Lines changed: 19 additions & 3 deletions
```diff
@@ -2,13 +2,29 @@
 
 All notable changes to LocalLab will be documented in this file.
 
+## [0.2.8] - 2025-03-02
+
+### Fixed
+
+- Fixed parameter mismatch in text generation endpoints by properly handling `max_new_tokens` parameter
+- Resolved coroutine awaiting issues in streaming generation endpoints
+- Fixed async generator handling in `stream_chat` and `generate_stream` functions
+- Enhanced error handling in streaming responses to provide better error messages
+- Improved compatibility between route parameters and model manager methods
+
 ## [0.2.7] - 2025-03-02
 
+### Added
+
+- Added missing dependencies in `setup.py`: huggingface_hub, pynvml, and typing_extensions
+- Improved dependency management with dev extras for testing packages
+- Enhanced error handling for GPU memory detection
+
 ### Fixed
 
-- Added missing dependency `fastapi-cache2` that was causing server startup errors
-- Added missing dependency `nvidia-ml-py3` to properly monitor NVIDIA GPUs
-- Improved error handling for GPU monitoring when dependencies are missing
+- Fixed circular import issues between modules
+- Improved error handling in system utilities
+- Enhanced compatibility with Google Colab environments
 
 ## [0.2.6] - 2025-03-02
 
```
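The headline fix in 0.2.8 is that text generation endpoints now accept `max_new_tokens` and map it to `max_length` internally instead of rejecting or ignoring it. As a rough illustration of the caller-side effect only — the route path, port, and payload shape below are assumptions, not shown in this diff — a client request might look like this:

```python
# Hypothetical client call illustrating the max_new_tokens fix in 0.2.8.
# The endpoint path ("/generate") and port are assumptions; check your
# LocalLab server for the actual route it exposes.
import requests

response = requests.post(
    "http://localhost:8000/generate",
    json={
        "prompt": "Explain what LocalLab does in one sentence.",
        "max_new_tokens": 64,   # previously mismatched; now mapped to max_length
        "temperature": 0.7,
    },
    timeout=60,
)
print(response.json())
```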

locallab/__init__.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -2,7 +2,7 @@
 LocalLab: Run LLMs locally with a friendly API similar to OpenAI
 """
 
-__version__ = "0.2.7"
+__version__ = "0.2.8"
 
 from typing import Dict, Any, Optional
 
```
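Since the package version is exposed as `locallab.__version__`, a quick way to confirm the upgrade took effect:

```python
# Verify the installed LocalLab version after upgrading to 0.2.8.
import locallab

assert locallab.__version__ == "0.2.8", f"Unexpected version: {locallab.__version__}"
print(locallab.__version__)  # expected: 0.2.8
```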

locallab/model_manager.py

Lines changed: 52 additions & 3 deletions
```diff
@@ -280,6 +280,7 @@ async def generate(
         prompt: str,
         stream: bool = False,
         max_length: Optional[int] = None,
+        max_new_tokens: Optional[int] = None,
         temperature: Optional[float] = None,
         top_p: Optional[float] = None,
         top_k: Optional[int] = None,
@@ -307,6 +308,10 @@ async def generate(
         from .config import get_model_generation_params
         gen_params = get_model_generation_params(self.current_model)
 
+        # Handle max_new_tokens parameter (map to max_length)
+        if max_new_tokens is not None:
+            max_length = max_new_tokens
+
         # Override with user-provided parameters if specified
         if max_length is not None:
             try:
@@ -423,8 +428,40 @@ def _stream_generate(
             logger.error(f"Streaming generation failed: {str(e)}")
             raise HTTPException(status_code=500, detail=f"Streaming generation failed: {str(e)}")
 
-    async def async_stream_generate(self, inputs: Dict[str, torch.Tensor], gen_params: Dict[str, Any]):
-        """Convert the synchronous stream generator to an async generator."""
+    async def async_stream_generate(self, inputs: Dict[str, torch.Tensor] = None, gen_params: Dict[str, Any] = None, prompt: str = None, system_prompt: Optional[str] = None, **kwargs):
+        """Convert the synchronous stream generator to an async generator.
+
+        This can be called either with:
+        1. inputs and gen_params directly (internal use)
+        2. prompt, system_prompt and other kwargs (from generate_stream adapter)
+        """
+        # If called with prompt, prepare inputs and parameters
+        if prompt is not None:
+            # Get appropriate system instructions
+            from .config import system_instructions
+            instructions = str(system_instructions.get_instructions(self.current_model)) if not system_prompt else str(system_prompt)
+
+            # Format prompt with system instructions
+            formatted_prompt = f"""<|system|>{instructions}</|system|>\n<|user|>{prompt}</|user|>\n<|assistant|>"""
+
+            # Get model-specific generation parameters
+            from .config import get_model_generation_params
+            gen_params = get_model_generation_params(self.current_model)
+
+            # Update with provided kwargs
+            for key, value in kwargs.items():
+                if key in ["max_length", "temperature", "top_p", "top_k", "repetition_penalty"]:
+                    gen_params[key] = value
+                elif key == "max_new_tokens":
+                    # Handle the max_new_tokens parameter by mapping to max_length
+                    gen_params["max_length"] = value
+
+            # Tokenize the prompt
+            inputs = self.tokenizer(formatted_prompt, return_tensors="pt")
+            for key in inputs:
+                inputs[key] = inputs[key].to(self.device)
+
+        # Now stream tokens using the prepared inputs and parameters
         for token in self._stream_generate(inputs, gen_params=gen_params):
             yield token
             await asyncio.sleep(0)
@@ -564,6 +601,11 @@ async def generate_text(self, prompt: str, system_prompt: Optional[str] = None,
         """
         # Make sure we're not streaming when generating text
         kwargs["stream"] = False
+
+        # Handle max_new_tokens parameter by mapping to max_length if needed
+        if "max_new_tokens" in kwargs and "max_length" not in kwargs:
+            kwargs["max_length"] = kwargs.pop("max_new_tokens")
+
         # Directly await the generate method to return the string result
         return await self.generate(prompt=prompt, system_instructions=system_prompt, **kwargs)
 
@@ -572,7 +614,14 @@ async def generate_stream(self, prompt: str, system_prompt: Optional[str] = None
         Calls the async_stream_generate method with proper parameters."""
         # Ensure streaming is enabled
         kwargs["stream"] = True
-        return self.async_stream_generate(prompt=prompt, system_prompt=system_prompt, **kwargs)
+
+        # Handle max_new_tokens parameter by mapping to max_length
+        if "max_new_tokens" in kwargs and "max_length" not in kwargs:
+            kwargs["max_length"] = kwargs.pop("max_new_tokens")
+
+        # Call async_stream_generate with the prompt and parameters
+        async for token in self.async_stream_generate(prompt=prompt, system_prompt=system_prompt, **kwargs):
+            yield token
 
     def is_model_loaded(self, model_id: str) -> bool:
         """Check if a specific model is loaded.
```

locallab/routes/generate.py

Lines changed: 17 additions & 12 deletions
```diff
@@ -5,7 +5,7 @@
 from fastapi import APIRouter, HTTPException, BackgroundTasks, Request
 from fastapi.responses import JSONResponse, StreamingResponse
 from pydantic import BaseModel, Field
-from typing import Dict, List, Any, Optional, Generator, Tuple
+from typing import Dict, List, Any, Optional, Generator, Tuple, AsyncGenerator
 import json
 
 from ..logger import get_logger
@@ -212,9 +212,9 @@ async def generate_stream(
     temperature: float,
     top_p: float,
     system_prompt: Optional[str]
-) -> Generator[str, None, None]:
+) -> AsyncGenerator[str, None]:
     """
-    Generate text in a streaming fashion
+    Generate text in a streaming fashion and return as server-sent events
     """
     try:
         # Get model-specific generation parameters
@@ -230,12 +230,15 @@ async def generate_stream(
         # Merge model-specific params with request params
         generation_params.update(model_params)
 
-        # Stream tokens
-        async for token in model_manager.generate_stream(
+        # Get the stream generator
+        stream_generator = model_manager.generate_stream(
             prompt=prompt,
             system_prompt=system_prompt,
             **generation_params
-        ):
+        )
+
+        # Stream tokens
+        async for token in stream_generator:
             # Format as server-sent event
             data = token.replace("\n", "\\n")
             yield f"data: {data}\n\n"
@@ -252,9 +255,9 @@ async def stream_chat(
     max_tokens: int,
    temperature: float,
     top_p: float
-) -> Generator[str, None, None]:
+) -> AsyncGenerator[str, None]:
     """
-    Stream chat completion
+    Stream chat completion responses as server-sent events
     """
     try:
         # Get model-specific generation parameters
@@ -270,19 +273,21 @@ async def stream_chat(
         # Merge model-specific params with request params
         generation_params.update(model_params)
 
-        # Generate streaming tokens
-        async for token in model_manager.generate_stream(
+        # Generate streaming tokens - properly await the async generator
+        stream_generator = model_manager.generate_stream(
             prompt=formatted_prompt,
             **generation_params
-        ):
+        )
+
+        async for token in stream_generator:
             # Format as a server-sent event with the structure expected by chat clients
             data = json.dumps({"role": "assistant", "content": token})
             yield f"data: {data}\n\n"
 
         # End of stream marker
         yield "data: [DONE]\n\n"
     except Exception as e:
-        logger.error(f"Chat streaming failed: {str(e)}")
+        logger.error(f"Streaming generation failed: {str(e)}")
         error_data = json.dumps({"error": str(e)})
         yield f"data: {error_data}\n\n"
         yield "data: [DONE]\n\n"
```

setup.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -5,7 +5,7 @@
 
 setup(
     name="locallab",
-    version="0.2.7",
+    version="0.2.8",
     packages=find_packages(include=["locallab", "locallab.*"]),
     install_requires=[
         "fastapi>=0.95.0,<1.0.0",
```

0 commit comments
