Skip to content

Commit a3d6f0f

Browse files
authored
Merge branch 'main' into fix-multimodal
2 parents 1cfaebc + f2abc34 commit a3d6f0f

File tree

15 files changed

+2060
-1843
lines changed

15 files changed

+2060
-1843
lines changed

.github/workflows/python-test.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ jobs:
1414
python-version: ['3.9', '3.10', '3.11', '3.12']
1515

1616
steps:
17-
- uses: actions/checkout@v3 # Updated to the latest version
17+
- uses: actions/checkout@v4 # Updated to the latest version
1818
- name: Set up Python ${{ matrix.python-version }}
1919
uses: actions/setup-python@v4 # Updated to the latest version
2020
with:
@@ -37,7 +37,7 @@ jobs:
3737
poetry run pytest
3838
3939
- name: Upload pytest results as an artifact (optional)
40-
uses: actions/upload-artifact@v3 # Updated to the latest version
40+
uses: actions/upload-artifact@v4 # Updated to the latest version
4141
if: always() # Always run this step to ensure test results are saved even if previous steps fail
4242
with:
4343
name: pytest-results

adalflow/CHANGELOG.md

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,23 @@
1-
## [0.2.7] - 2024-09-23
1+
## [0.2.7] - 2025-01-16
22

3-
### Improved
4-
- Better diagnose report for `Trainer.diagnose`.
5-
- Multi-hop RAG with handling of Cycle.
6-
7-
## [0.2.7] - To Be Released
83
### Added
94
- `Memory` is completed with `call` and `add_dialog_turn` methods.
105
- Integrated `LanceDB` in the `Retriever`
6+
- Multi-modal (image input and generation) in `OpenAIClient` along with tests.
7+
- `ComponentList` to support a list of components registered in a component. Added `test_componentlist` to test the `ComponentList`.
8+
119
### Improved
10+
- Better diagnose report for `Trainer.diagnose`.
1211
- `BedrockAPIClient` added more details on setup, yet it is still in experimental stage.
1312
- `AzureAPIClient` added more details on setup, yet it is still in experimental stage.
13+
- `Retriever` class:
14+
- Support data id (field).
15+
- `GradComponent`: Support pass-through gradient for the `forward` method.
16+
17+
### Optimization
18+
- Aggregated all backward engine prompts in `backward_engine_prompt`.
19+
- Added `TGDData` for the optimizer to support reasoning at proposing new prompt.
20+
- Added `sequential_order` in the `Trainer` to support the sequential training order. Reorganized the trainer code.
1421
## [0.2.6] - 2024-11-25
1522
### Improved
1623
- Add default `max_tokens=512` to the `AnthropicAPIClient` to avoid the error when the user does not provide the `max_tokens` in the prompt.

adalflow/adalflow/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = "0.2.6"
1+
__version__ = "0.2.7"
22

33
from adalflow.core.component import Component, fun_to_component
44
from adalflow.core.container import Sequential, ComponentList

adalflow/adalflow/components/model_client/bedrock_client.py

Lines changed: 62 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
"""AWS Bedrock ModelClient integration."""
22

3+
import json
34
import os
4-
from typing import Dict, Optional, Any, Callable
5+
from typing import Dict, Optional, Any, Callable, Generator as GeneratorType
56
import backoff
67
import logging
78

@@ -26,7 +27,6 @@ def get_first_message_content(completion: Dict) -> str:
2627
r"""When we only need the content of the first message.
2728
It is the default parser for chat completion."""
2829
return completion["output"]["message"]["content"][0]["text"]
29-
return completion["output"]["message"]["content"][0]["text"]
3030

3131

3232
__all__ = [
@@ -117,6 +117,7 @@ def __init__(
117117
self._aws_connection_timeout = aws_connection_timeout
118118
self._aws_read_timeout = aws_read_timeout
119119

120+
self._client = None
120121
self.session = None
121122
self.sync_client = self.init_sync_client()
122123
self.chat_completion_parser = (
@@ -158,16 +159,51 @@ def init_sync_client(self):
158159
def init_async_client(self):
159160
raise NotImplementedError("Async call not implemented yet.")
160161

161-
def parse_chat_completion(self, completion):
162-
log.debug(f"completion: {completion}")
162+
def handle_stream_response(self, stream: dict) -> GeneratorType:
163+
r"""Handle the stream response from bedrock. Yield the chunks.
164+
165+
Args:
166+
stream (dict): The stream response generator from bedrock.
167+
168+
Returns:
169+
GeneratorType: A generator that yields the chunks from bedrock stream.
170+
"""
171+
try:
172+
stream: GeneratorType = stream["stream"]
173+
for chunk in stream:
174+
log.debug(f"Raw chunk: {chunk}")
175+
yield chunk
176+
except Exception as e:
177+
log.debug(f"Error in handle_stream_response: {e}") # Debug print
178+
raise
179+
180+
def parse_chat_completion(self, completion: dict) -> "GeneratorOutput":
181+
r"""Parse the completion, and assign it into the raw_response attribute.
182+
183+
If the completion is a stream, it will be handled by the handle_stream_response
184+
method that returns a Generator. Otherwise, the completion will be parsed using
185+
the get_first_message_content method.
186+
187+
Args:
188+
completion (dict): The completion response from bedrock API call.
189+
190+
Returns:
191+
GeneratorOutput: A generator output object with the parsed completion. May
192+
return a generator if the completion is a stream.
193+
"""
163194
try:
164-
data = completion["output"]["message"]["content"][0]["text"]
165-
usage = self.track_completion_usage(completion)
166-
return GeneratorOutput(data=None, usage=usage, raw_response=data)
195+
usage = None
196+
data = self.chat_completion_parser(completion)
197+
if not isinstance(data, GeneratorType):
198+
# Streaming completion usage tracking is not implemented.
199+
usage = self.track_completion_usage(completion)
200+
return GeneratorOutput(
201+
data=None, error=None, raw_response=data, usage=usage
202+
)
167203
except Exception as e:
168-
log.error(f"Error parsing completion: {e}")
204+
log.error(f"Error parsing the completion: {e}")
169205
return GeneratorOutput(
170-
data=None, error=str(e), raw_response=str(completion)
206+
data=None, error=str(e), raw_response=json.dumps(completion)
171207
)
172208

173209
def track_completion_usage(self, completion: Dict) -> CompletionUsage:
@@ -191,6 +227,7 @@ def list_models(self):
191227
print(f" Description: {model['description']}")
192228
print(f" Provider: {model['provider']}")
193229
print("")
230+
194231
except Exception as e:
195232
print(f"Error listing models: {e}")
196233

@@ -222,14 +259,27 @@ def convert_inputs_to_api_kwargs(
222259
bedrock_runtime_exceptions.ModelErrorException,
223260
bedrock_runtime_exceptions.ValidationException,
224261
),
225-
max_time=5,
262+
max_time=2,
226263
)
227-
def call(self, api_kwargs: Dict = {}, model_type: ModelType = ModelType.UNDEFINED):
264+
def call(
265+
self,
266+
api_kwargs: Dict = {},
267+
model_type: ModelType = ModelType.UNDEFINED,
268+
) -> dict:
228269
"""
229270
kwargs is the combined input and model_kwargs
230271
"""
231272
if model_type == ModelType.LLM:
232-
return self.sync_client.converse(**api_kwargs)
273+
if "stream" in api_kwargs and api_kwargs.get("stream", False):
274+
log.debug("Streaming call")
275+
api_kwargs.pop(
276+
"stream", None
277+
) # stream is not a valid parameter for bedrock
278+
self.chat_completion_parser = self.handle_stream_response
279+
return self.sync_client.converse_stream(**api_kwargs)
280+
else:
281+
api_kwargs.pop("stream", None)
282+
return self.sync_client.converse(**api_kwargs)
233283
else:
234284
raise ValueError(f"model_type {model_type} is not supported")
235285

adalflow/adalflow/components/model_client/openai_client.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,7 @@ def convert_inputs_to_api_kwargs(
361361

362362
else:
363363
raise ValueError(f"Invalid operation: {operation}")
364+
364365
else:
365366
raise ValueError(f"model_type {self.model_type} is not supported")
366367
return final_model_kwargs
@@ -471,6 +472,7 @@ async def acall(
471472
api_kwargs["image"].close()
472473
if "mask" in api_kwargs and hasattr(api_kwargs["mask"], "close"):
473474
api_kwargs["mask"].close()
475+
474476
else:
475477
raise ValueError(f"model_type {model_type} is not supported")
476478

0 commit comments

Comments (0)