Commit f4ecf3c

[Feat] Fixes for LiteLLM Proxy CLI to Auth to Gateway (#14836)
* fix: error msg from updating key
* fix _create_new_cli_key
* fix validate_key_team_change
* fix interface for chat
* ruff fix
* fix auth for keys
* get_litellm_gateway_api_key
* fix chat
* test fix
* linting fix
* fix mypy
* test_validate_key_team_change_with_member_permissions
1 parent: e1b3426

File tree

10 files changed: +503 −335 lines

litellm/proxy/client/chat.py

Lines changed: 92 additions & 1 deletion
@@ -1,5 +1,8 @@
+import json
+from typing import Any, Dict, Iterator, List, Optional, Union
+
 import requests
-from typing import List, Dict, Any, Optional, Union
+
 from .exceptions import UnauthorizedError
 
 
@@ -99,3 +102,91 @@ def completions(
             if e.response.status_code == 401:
                 raise UnauthorizedError(e)
             raise
+
+    def completions_stream(
+        self,
+        model: str,
+        messages: List[Dict[str, str]],
+        temperature: Optional[float] = None,
+        top_p: Optional[float] = None,
+        n: Optional[int] = None,
+        max_tokens: Optional[int] = None,
+        presence_penalty: Optional[float] = None,
+        frequency_penalty: Optional[float] = None,
+        user: Optional[str] = None,
+    ) -> Iterator[Dict[str, Any]]:
+        """
+        Create a streaming chat completion.
+
+        Args:
+            model (str): The model to use for completion
+            messages (List[Dict[str, str]]): The messages to generate a completion for
+            temperature (Optional[float]): Sampling temperature between 0 and 2
+            top_p (Optional[float]): Nucleus sampling parameter between 0 and 1
+            n (Optional[int]): Number of completions to generate
+            max_tokens (Optional[int]): Maximum number of tokens to generate
+            presence_penalty (Optional[float]): Presence penalty between -2.0 and 2.0
+            frequency_penalty (Optional[float]): Frequency penalty between -2.0 and 2.0
+            user (Optional[str]): Unique identifier for the end user
+
+        Yields:
+            Dict[str, Any]: Streaming response chunks from the server
+
+        Raises:
+            UnauthorizedError: If the request fails with a 401 status code
+            requests.exceptions.RequestException: If the request fails with any other error
+        """
+        url = f"{self._base_url}/chat/completions"
+
+        # Build request data with required fields
+        data: Dict[str, Any] = {
+            "model": model,
+            "messages": messages,
+            "stream": True
+        }
+
+        # Add optional parameters if provided
+        if temperature is not None:
+            data["temperature"] = temperature
+        if top_p is not None:
+            data["top_p"] = top_p
+        if n is not None:
+            data["n"] = n
+        if max_tokens is not None:
+            data["max_tokens"] = max_tokens
+        if presence_penalty is not None:
+            data["presence_penalty"] = presence_penalty
+        if frequency_penalty is not None:
+            data["frequency_penalty"] = frequency_penalty
+        if user is not None:
+            data["user"] = user
+
+        # Make streaming request
+        session = requests.Session()
+        try:
+            response = session.post(
+                url,
+                headers=self._get_headers(),
+                json=data,
+                stream=True
+            )
+            response.raise_for_status()
+
+            # Parse SSE stream
+            for line in response.iter_lines():
+                if line:
+                    line = line.decode('utf-8')
+                    if line.startswith('data: '):
+                        data_str = line[6:]  # Remove 'data: ' prefix
+                        if data_str.strip() == '[DONE]':
+                            break
+                        try:
+                            chunk = json.loads(data_str)
+                            yield chunk
+                        except json.JSONDecodeError:
+                            continue
+
+        except requests.exceptions.HTTPError as e:
+            if e.response.status_code == 401:
+                raise UnauthorizedError(e)
+            raise
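
For orientation, here is a minimal sketch of how the new completions_stream method might be consumed. It is not part of the commit: the client.chat accessor, base URL, API key, and model name are all assumptions for illustration; only the method signature comes from the diff above.

# Hypothetical usage sketch (placeholder values, not from the commit).
from litellm.proxy.client import Client

client = Client(base_url="http://localhost:4000", api_key="sk-...")  # assumed values

for chunk in client.chat.completions_stream(  # `client.chat` accessor is assumed
    model="gpt-4o",  # placeholder model name
    messages=[{"role": "user", "content": "Hello!"}],
    temperature=0.7,
):
    # Each yielded chunk is one parsed SSE JSON object; in OpenAI-style
    # streams the incremental text lives under choices[0]["delta"]["content"].
    delta = chunk.get("choices", [{}])[0].get("delta", {})
    if "content" in delta:
        print(delta["content"], end="", flush=True)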

litellm/proxy/client/cli/commands/auth.py

Lines changed: 2 additions & 2 deletions
@@ -281,12 +281,12 @@ def prompt_team_selection_fallback(teams: List[Dict[str, Any]]) -> Optional[Dict
 
 def update_key_with_team(base_url: str, api_key: str, team_id: str) -> bool:
     """Update the API key to be associated with the selected team"""
-
+    from litellm.proxy._types import SpecialModelNames
    from litellm.proxy.client import Client
 
     client = Client(base_url=base_url, api_key=api_key)
     try:
-        client.keys.update(key=api_key, team_id=team_id)
+        client.keys.update(key=api_key, team_id=team_id, models=[SpecialModelNames.all_team_models.value])
         click.echo(f"✅ Successfully assigned key to team: {team_id}")
         return True
     except requests.exceptions.HTTPError as e:
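
The substantive fix here is the added models argument: without it, a key reassigned to a team kept its previous model list, which could break access after the team change. Passing the all_team_models sentinel tells the proxy the key may use whatever models its team allows. A minimal sketch of the equivalent direct call, with placeholder URL, key, and team id:

# Hypothetical standalone sketch (placeholder values, not from the commit).
from litellm.proxy._types import SpecialModelNames
from litellm.proxy.client import Client

client = Client(base_url="http://localhost:4000", api_key="sk-...")

# The sentinel resolves to a reserved model name that grants the key
# access to every model its team is allowed to use.
client.keys.update(
    key="sk-...",
    team_id="team-alpha",  # placeholder team id
    models=[SpecialModelNames.all_team_models.value],
)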
