@@ -13,6 +13,7 @@
 
 from adalflow.core.model_client import ModelClient
 from adalflow.core.types import ModelType, CompletionUsage, GeneratorOutput
+from adalflow.utils import printc
 
 from adalflow.utils.lazy_import import safe_import, OptionalPackages
 
@@ -165,27 +166,20 @@ def init_sync_client(self): |
     def init_async_client(self):
         raise NotImplementedError("Async call not implemented yet.")
 
-    @staticmethod
-    def parse_stream_response(completion: dict) -> str:
-        if "contentBlockDelta" in completion:
-            if delta_chunk := completion["contentBlockDelta"]["delta"]:
-                return delta_chunk["text"]
-        return ''
-
     def handle_stream_response(self, stream: dict) -> GeneratorType:
         try:
-            for chunk in stream["stream"]:
+            stream: GeneratorType = stream["stream"]
+            for chunk in stream:
                 log.debug(f"Raw chunk: {chunk}")
-                parsed_content = self.parse_stream_response(chunk)
-                yield parsed_content
+                yield chunk
         except Exception as e:
             print(f"Error in handle_stream_response: {e}") # Debug print
             raise
 
     def parse_chat_completion(self, completion: dict) -> "GeneratorOutput":
         """Parse the completion, and put it into the raw_response."""
         try:
-            data = self.handle_stream_response(completion)
+            data = self.chat_completion_parser(completion)
             return GeneratorOutput(
                 data=None, error=None, raw_response=data
             )
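
With this change, handle_stream_response yields the raw ConverseStream event dicts instead of pre-extracted text, and parse_chat_completion defers to whichever chat_completion_parser was set during call(). A minimal sketch of how a consumer might pull text deltas back out of the raw events, reusing the field names ("contentBlockDelta", "delta", "text") from the removed parse_stream_response; the helper name extract_text is hypothetical, not part of this commit:

# Hypothetical helper (not in this commit): filter raw Bedrock
# ConverseStream events down to their text deltas, mirroring the
# logic of the removed parse_stream_response.
def extract_text(chunks):
    for chunk in chunks:
        if "contentBlockDelta" in chunk:
            delta = chunk["contentBlockDelta"].get("delta", {})
            if "text" in delta:
                yield delta["text"]
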
@@ -254,19 +248,19 @@ def call( |
         self,
         api_kwargs: Dict = {},
         model_type: ModelType = ModelType.UNDEFINED,
-        stream: bool = False
     ) -> dict:
         """
         kwargs is the combined input and model_kwargs
         """
         if model_type == ModelType.LLM:
             if "stream" in api_kwargs and api_kwargs.get("stream", False):
                 log.debug("Streaming call")
+                printc("Streaming")
                 api_kwargs.pop("stream") # stream is not a valid parameter for bedrock
                 self.chat_completion_parser = self.handle_stream_response
                 return self.sync_client.converse_stream(**api_kwargs)
             else:
-                api_kwargs.pop("stream")
+                api_kwargs.pop("stream", None)
                 return self.sync_client.converse(**api_kwargs)
         else:
             raise ValueError(f"model_type {model_type} is not supported")
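
End-to-end, a streaming call might look like the sketch below. It assumes the class in this file is AdalFlow's BedrockAPIClient and that the caller has Bedrock access; the model id and prompt are placeholders, and the messages shape follows the Bedrock Converse API.

from adalflow.core.types import ModelType

client = BedrockAPIClient()  # assumed class name for this module's client
api_kwargs = {
    "modelId": "anthropic.claude-3-haiku-20240307-v1:0",  # placeholder model id
    "messages": [{"role": "user", "content": [{"text": "Hello"}]}],
    "stream": True,  # popped before the boto3 call; routes to converse_stream
}
completion = client.call(api_kwargs=api_kwargs, model_type=ModelType.LLM)
output = client.parse_chat_completion(completion)
for chunk in output.raw_response:  # generator of raw ConverseStream events
    print(chunk)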