|
| 1 | +import json |
| 2 | +from typing import Any, Dict, Iterator, List, Optional, Union |
| 3 | + |
1 | 4 | import requests
|
2 |
| -from typing import List, Dict, Any, Optional, Union |
| 5 | + |
3 | 6 | from .exceptions import UnauthorizedError
|
4 | 7 |
|
5 | 8 |
|
@@ -99,3 +102,91 @@ def completions(
|
99 | 102 | if e.response.status_code == 401:
|
100 | 103 | raise UnauthorizedError(e)
|
101 | 104 | raise
|
| 105 | + |
def completions_stream(
    self,
    model: str,
    messages: List[Dict[str, str]],
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    n: Optional[int] = None,
    max_tokens: Optional[int] = None,
    presence_penalty: Optional[float] = None,
    frequency_penalty: Optional[float] = None,
    user: Optional[str] = None,
) -> Iterator[Dict[str, Any]]:
    """
    Create a streaming chat completion.

    This is a generator: the HTTP request is only sent once iteration
    starts, so the exceptions listed below surface on the first ``next()``
    call rather than when this method is called. The HTTP session and the
    streaming response are closed when the stream ends, when an error is
    raised, or when the caller stops iterating and the generator is closed.

    Args:
        model (str): The model to use for completion
        messages (List[Dict[str, str]]): The messages to generate a completion for
        temperature (Optional[float]): Sampling temperature between 0 and 2
        top_p (Optional[float]): Nucleus sampling parameter between 0 and 1
        n (Optional[int]): Number of completions to generate
        max_tokens (Optional[int]): Maximum number of tokens to generate
        presence_penalty (Optional[float]): Presence penalty between -2.0 and 2.0
        frequency_penalty (Optional[float]): Frequency penalty between -2.0 and 2.0
        user (Optional[str]): Unique identifier for the end user

    Yields:
        Dict[str, Any]: Streaming response chunks from the server, one per
        ``data:`` event whose payload is valid JSON. Events with a payload
        that is not valid JSON are skipped.

    Raises:
        UnauthorizedError: If the request fails with a 401 status code
        requests.exceptions.RequestException: If the request fails with any other error
    """
    url = f"{self._base_url}/chat/completions"

    # Required fields; "stream" asks the server for an SSE response.
    data: Dict[str, Any] = {
        "model": model,
        "messages": messages,
        "stream": True,
    }

    # Only send optional parameters the caller actually provided.
    optional_params: Dict[str, Any] = {
        "temperature": temperature,
        "top_p": top_p,
        "n": n,
        "max_tokens": max_tokens,
        "presence_penalty": presence_penalty,
        "frequency_penalty": frequency_penalty,
        "user": user,
    }
    data.update(
        {key: value for key, value in optional_params.items() if value is not None}
    )

    try:
        # Context managers guarantee the session and the streamed response
        # are released even if the consumer abandons the generator early.
        with requests.Session() as session:
            with session.post(
                url,
                headers=self._get_headers(),
                json=data,
                stream=True,
            ) as response:
                response.raise_for_status()

                # Parse the SSE stream line by line.
                for raw_line in response.iter_lines():
                    if not raw_line:
                        # Blank lines only separate events.
                        continue
                    line = raw_line.decode("utf-8")
                    # The space after "data:" is optional in the SSE format,
                    # so match on the field name and strip whitespace after.
                    if not line.startswith("data:"):
                        continue
                    payload = line[len("data:"):].strip()
                    if payload == "[DONE]":
                        break
                    try:
                        chunk = json.loads(payload)
                    except json.JSONDecodeError:
                        # Best effort: ignore events that are not valid JSON.
                        continue
                    yield chunk
    except requests.exceptions.HTTPError as e:
        if e.response.status_code == 401:
            raise UnauthorizedError(e) from e
        raise
0 commit comments